diff --git a/api/helpers/chat.py b/api/helpers/chat.py
index 2526241..7dccb24 100644
--- a/api/helpers/chat.py
+++ b/api/helpers/chat.py
@@ -1,3 +1,4 @@
+import json
 import string
 import random
 import asyncio
@@ -26,11 +27,14 @@ def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
         'content': content
     }
 
-    if not isinstance(content, str):
+    if content == CompletionStart:
        delta = {
            'role': 'assistant'
        }
 
+    if content == CompletionStop:
+        delta = {}
+
     chunk = {
         'id': chat_id,
         'object': 'chat.completion.chunk',
@@ -40,13 +44,12 @@ def create_chat_chunk(chat_id: str, model: str, content=None) -> dict:
             {
                 'delta': delta,
                 'index': 0,
-                'finish_reason': None if not(isinstance(content, str)) else 'stop'
+                'finish_reason': 'stop' if content == CompletionStop else None
             }
         ],
     }
 
-    print(chunk)
-    return chunk
+    return f'data: {json.dumps(chunk)}\n\n'
 
 if __name__ == '__main__':
     demo_chat_id = asyncio.run(create_chat_id())
diff --git a/api/helpers/network.py b/api/helpers/network.py
index 26d810c..76a3bad 100644
--- a/api/helpers/network.py
+++ b/api/helpers/network.py
@@ -2,7 +2,19 @@ import base64
 import asyncio
 
 async def get_ip(request) -> str:
-    return request.client.host
+    xff = None
+    if request.headers.get('x-forwarded-for'):
+        xff, *_ = request.headers['x-forwarded-for'].split(', ')
+
+    possible_ips = [
+        xff,
+        request.headers.get('cf-connecting-ip'),
+        request.client.host
+    ]
+
+    detected_ip = next((i for i in possible_ips if i), None)
+
+    return detected_ip
 
 async def add_proxy_auth_to_headers(username: str, password: str, headers: dict) -> dict:
     proxy_auth = base64.b64encode(f'{username}:{password}'.encode()).decode()
diff --git a/api/load_balancing.py b/api/load_balancing.py
index 6fa4696..263dfe2 100644
--- a/api/load_balancing.py
+++ b/api/load_balancing.py
@@ -4,11 +4,11 @@ import asyncio
 import chat_providers
 
 provider_modules = [
-    chat_providers.twa,
+    # chat_providers.twa,
     # chat_providers.quantum,
-    # chat_providers.churchless,
-    # chat_providers.closed,
-    # chat_providers.closed4
+    chat_providers.churchless,
+    chat_providers.closed,
+    chat_providers.closed4
 ]
 
 def _get_module_name(module) -> str:
@@ -29,6 +29,9 @@ async def balance_chat_request(payload: dict) -> dict:
 
             providers_available.append(provider_module)
 
+    if not providers_available:
+        raise NotImplementedError('This model does not exist.')
+
     provider = random.choice(providers_available)
     target = provider.chat_completion(**payload)
     target['module'] = _get_module_name(provider)
diff --git a/api/main.py b/api/main.py
index cb593fe..b177efa 100644
--- a/api/main.py
+++ b/api/main.py
@@ -2,9 +2,9 @@
 
 import fastapi
 
-from fastapi.middleware.cors import CORSMiddleware
-
+from rich import print
 from dotenv import load_dotenv
+from fastapi.middleware.cors import CORSMiddleware
 
 import core
 import transfer
diff --git a/api/proxies.py b/api/proxies.py
index ee7348d..639607a 100644
--- a/api/proxies.py
+++ b/api/proxies.py
@@ -7,6 +7,7 @@ import asyncio
 
 import aiohttp
 import aiohttp_socks
+from rich import print
 from dotenv import load_dotenv
 
 load_dotenv()
@@ -71,14 +72,17 @@ class Proxy:
 
 proxies_in_files = []
 
-for proxy_type in ['http', 'socks4', 'socks5']:
-    with open(f'secret/proxies/{proxy_type}.txt') as f:
-        for line in f.readlines():
-            if line.strip() and not line.strip().startswith('#'):
-                if '#' in line:
-                    line = line.split('#')[0]
+try:
+    for proxy_type in ['http', 'socks4', 'socks5']:
+        with open(f'secret/proxies/{proxy_type}.txt') as f:
+            for line in f.readlines():
+                if line.strip() and not line.strip().startswith('#'):
+                    if '#' in line:
+                        line = line.split('#')[0]
 
-                proxies_in_files.append(f'{proxy_type}://{line.strip()}')
+                    proxies_in_files.append(f'{proxy_type}://{line.strip()}')
+except FileNotFoundError:
+    pass
 
 class ProxyChain:
     def __init__(self):
@@ -87,7 +91,11 @@ class ProxyChain:
         self.get_random = Proxy(url=random_proxy)
         self.connector = aiohttp_socks.ChainProxyConnector.from_urls(proxies_in_files)
 
-default_chain = ProxyChain()
+try:
+    default_chain = ProxyChain()
+    random_proxy = ProxyChain().get_random
+except IndexError:
+    pass
 
 default_proxy = Proxy(
     proxy_type=os.getenv('PROXY_TYPE', 'http'),
@@ -97,7 +105,6 @@ default_proxy = Proxy(
     password=os.getenv('PROXY_PASS')
 )
 
-random_proxy = ProxyChain().get_random
 
 def test_httpx_workaround():
     import httpx
@@ -129,24 +136,11 @@ async def test_aiohttp_socks():
 
 async def streaming_aiohttp_socks():
     async with aiohttp.ClientSession(connector=default_proxy.connector) as session:
-        async with session.post(
-            'https://free.churchless.tech/v1/chat/completions',
-            json={
-                "model": "gpt-3.5-turbo",
-                "messages": [
-                    {
-                        "role": "user",
-                        "content": "Hi"
-                    }
-                ],
-                "stream": True
-            },
-            # headers={
-            #     'Authorization': 'Bearer MyDiscord'
-            # }
-        ) as response:
-            html = await response.text()
-            return html.strip()
+        async with session.get('https://httpbin.org/get', headers={
+            'Authorization': 'x'
+        }) as response:
+            json = await response.json()
+            return json
 
 async def text_httpx_socks():
     import httpx
@@ -163,5 +157,5 @@ if __name__ == '__main__':
     # print(test_httpx())
     # print(test_requests())
     # print(asyncio.run(test_aiohttp_socks()))
-    # print(asyncio.run(streaming_aiohttp_socks()))
-    print(asyncio.run(text_httpx_socks()))
+    print(asyncio.run(streaming_aiohttp_socks()))
+    # print(asyncio.run(text_httpx_socks()))
diff --git a/api/streaming.py b/api/streaming.py
index a82798c..df3011e 100644
--- a/api/streaming.py
+++ b/api/streaming.py
@@ -7,6 +7,7 @@ import starlette
 
 from rich import print
 from dotenv import load_dotenv
+from python_socks._errors import ProxyError
 
 import proxies
 import load_balancing
@@ -38,6 +39,7 @@ async def stream(
     input_tokens: int=0,
     incoming_request: starlette.requests.Request=None,
 ):
+    payload = payload or DEMO_PAYLOAD
 
     is_chat = False
 
@@ -46,95 +48,114 @@ async def stream(
         chat_id = await chat.create_chat_id()
         model = payload['model']
 
-        chat_chunk = chat.create_chat_chunk(
+        yield chat.create_chat_chunk(
             chat_id=chat_id,
             model=model,
             content=chat.CompletionStart
         )
 
-        data = json.dumps(chat_chunk)
-        chunk = f'data: {data}'
-
-        yield chunk
+        yield chat.create_chat_chunk(
+            chat_id=chat_id,
+            model=model,
+            content=None
+        )
 
     for _ in range(5):
-        if is_chat:
-            target_request = await load_balancing.balance_chat_request(payload)
-        else:
-            target_request = await load_balancing.balance_organic_request(payload)
-
         headers = {
             'Content-Type': 'application/json'
         }
 
+        if is_chat:
+            target_request = await load_balancing.balance_chat_request(payload)
+        else:
+            target_request = await load_balancing.balance_organic_request({
+                'path': path,
+                'payload': payload,
+                'headers': headers
+            })
+
         for k, v in target_request.get('headers', {}).items():
             headers[k] = v
 
-        async with aiohttp.ClientSession(connector=proxies.random_proxy.connector) as session:
-            async with session.request(
-                method=target_request.get('method', 'POST'),
-                url=target_request['url'],
+        async with aiohttp.ClientSession(connector=proxies.default_proxy.connector) as session:
 
-                data=target_request.get('data'),
-                json=target_request.get('payload'),
+            try:
+                async with session.request(
+                    method=target_request.get('method', 'POST'),
+                    url=target_request['url'],
 
-                headers=headers,
-                cookies=target_request.get('cookies'),
+                    data=target_request.get('data'),
+                    json=target_request.get('payload'),
 
-                ssl=False,
+                    headers=headers,
+                    cookies=target_request.get('cookies'),
 
-                timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))),
-            ) as response:
-                try:
-                    await response.raise_for_status()
-                except Exception as exc:
-                    continue
-                    # if 'Too Many Requests' in str(exc):
+                    ssl=False,
 
-                if user and incoming_request:
-                    await logs.log_api_request(
-                        user=user,
-                        incoming_request=incoming_request,
-                        target_url=target_request['url']
-                    )
+                    timeout=aiohttp.ClientTimeout(total=float(os.getenv('TRANSFER_TIMEOUT', '120'))),
+                ) as response:
 
-                if credits_cost and user:
-                    await users.update_by_id(user['_id'], {
-                        '$inc': {'credits': -credits_cost}
-                    })
+                    try:
+                        response.raise_for_status()
+                    except Exception as exc:
+                        if 'Too Many Requests' in str(exc):
+                            print(429)
+                        continue
 
-                if not demo_mode:
-                    ip_address = await network.get_ip(incoming_request)
+                    if user and incoming_request:
+                        await logs.log_api_request(
+                            user=user,
+                            incoming_request=incoming_request,
+                            target_url=target_request['url']
+                        )
 
-                    await stats.add_date()
-                    await stats.add_ip_address(ip_address)
-                    await stats.add_path(path)
-                    await stats.add_target(target_request['url'])
+                    if credits_cost and user:
+                        await users.update_by_id(user['_id'], {
+                            '$inc': {'credits': -credits_cost}
+                        })
 
-                    if is_chat:
-                        await stats.add_model(model)
-                        await stats.add_tokens(input_tokens, model)
-
-                async for chunk in response.content.iter_any():
-                    chunk = f'{chunk.decode("utf8")}\n\n'
+                    try:
+                        async for chunk in response.content.iter_any():
+                            chunk = f'{chunk.decode("utf8")}\n\n'
+
+                            if chunk.strip():
+                                if is_chat:
+                                    if target_request['module'] == 'twa':
+                                        data = json.loads(chunk.split('data: ')[1])
+
+                                        if data.get('text'):
+                                            chunk = chat.create_chat_chunk(
+                                                chat_id=chat_id,
+                                                model=model,
+                                                content=['text']
+                                            )
+                            yield chunk
+
+                    except Exception as exc:
+                        if 'Connection closed' in str(exc):
+                            print('connection closed')
+                        continue
+
+                    if not demo_mode:
+                        ip_address = await network.get_ip(incoming_request)
+
+                        await stats.add_date()
+                        await stats.add_ip_address(ip_address)
+                        await stats.add_path(path)
+                        await stats.add_target(target_request['url'])
 
-                    if chunk.strip():
                         if is_chat:
-                            if target_request['module'] == 'twa':
-                                data = json.loads(chunk.split('data: ')[1])
+                            await stats.add_model(model)
+                            await stats.add_tokens(input_tokens, model)
 
-                            if data.get('text'):
-                                chat_chunk = chat.create_chat_chunk(
-                                    chat_id=chat_id,
-                                    model=model,
-                                    content=['text']
-                                )
-                                data = json.dumps(chat_chunk)
+                    break
 
-                                chunk = f'data: {data}'
+            except ProxyError:
+                print('proxy error')
+                continue
 
-                    yield chunk
-
-        break
+    print(3)
 
     if is_chat:
         chat_chunk = chat.create_chat_chunk(
             chat_id=chat_id,
@@ -143,8 +164,7 @@
         )
         data = json.dumps(chat_chunk)
 
-        yield f'data: {data}'
-    yield 'data: [DONE]'
+    yield 'data: [DONE]\n\n'
 
 if __name__ == '__main__':
     asyncio.run(stream())
diff --git a/tests/__main__.py b/tests/__main__.py
index 9f4db02..80b0792 100644
--- a/tests/__main__.py
+++ b/tests/__main__.py
@@ -86,8 +86,27 @@ def test_all():
     # print(test_api())
     print(test_library())
 
+
+def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
+    """Tests an API api_endpoint."""
+
+    headers = {
+        'Authorization': 'Bearer ' + api_key
+    }
+
+    response = httpx.get(
+        url=f'{api_endpoint}/v1/usage',
+        headers=headers,
+        timeout=20
+    )
+    response.raise_for_status()
+
+    return response.text
+
 if __name__ == '__main__':
     # api_endpoint = 'https://api.nova-oss.com'
-    api_endpoint = 'http://localhost:2332'
+    api_endpoint = 'https://alpha-api.nova-oss.com'
     api_key = os.getenv('TEST_NOVA_KEY')
-    test_all()
+    # test_all()
+
+    print(test_api())
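
Reviewer note, not part of the patch: with this change chat.create_chat_chunk returns a pre-framed server-sent-events string ('data: {...}\n\n') instead of a dict, and stream() now closes every response with 'data: [DONE]\n\n'. The sketch below only illustrates how a client might consume that framing; the example frames and the chat id are hypothetical, assembled from the chunk fields visible in this diff rather than captured from the running API.

    import json

    def iter_sse_events(raw: str):
        # Split on blank lines, the frame boundary create_chat_chunk now appends.
        for frame in raw.split('\n\n'):
            frame = frame.strip()
            if not frame.startswith('data: '):
                continue
            payload = frame[len('data: '):]
            if payload == '[DONE]':  # terminator yielded at the end of stream()
                return
            yield json.loads(payload)

    # Hypothetical frames shaped like the chunks built by create_chat_chunk:
    # a role chunk (CompletionStart), a content chunk, and a stop chunk (CompletionStop).
    example = (
        'data: {"id": "demo-id", "object": "chat.completion.chunk",'
        ' "choices": [{"delta": {"role": "assistant"}, "index": 0, "finish_reason": null}]}\n\n'
        'data: {"id": "demo-id", "object": "chat.completion.chunk",'
        ' "choices": [{"delta": {"content": "Hi"}, "index": 0, "finish_reason": null}]}\n\n'
        'data: {"id": "demo-id", "object": "chat.completion.chunk",'
        ' "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}]}\n\n'
        'data: [DONE]\n\n'
    )

    for event in iter_sse_events(example):
        print(event['choices'][0])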