Re-wrote net transfer

This commit is contained in:
nsde 2023-07-25 19:45:21 +02:00
parent 605fbc51a3
commit ee5b8561df
6 changed files with 85 additions and 29 deletions

13
api/apihandler.py Normal file
View file

@ -0,0 +1,13 @@
from typing import Union, Optional
class Request:
def __init__(self,
method: str,
url: str,
json_payload: Optional[Union[dict, list]]=None,
headers: dict=None
):
self.method = method
self.url = url
self.json = json_payload
self.headers = headers or {}

View file

@ -1,18 +1,23 @@
import os
import aiohttp
import proxies
from dotenv import load_dotenv
from request_manager import Request
load_dotenv()
async def receive_target_stream():
async def stream_closedai_request(request: Request):
async with aiohttp.ClientSession(
timeout=aiohttp.ClientTimeout(total=int(os.getenv('TRANSFER_TIMEOUT', '120'))),
connector=await proxies.default_proxy.get_connector(),
timeout=aiohttp.ClientTimeout(total=request.timeout),
raise_for_status=False
) as session:
async with session.request(
method=incoming_request.method,
url=target_url,
json=incoming_json_payload,
method=request.method,
url=request.url,
json=request.payload,
headers={
'Content-Type': 'application/json',
'Authorization': f'Bearer {os.getenv("CLOSEDAI_KEY")}'
@ -20,5 +25,7 @@ async def receive_target_stream():
) as response:
async for chunk in response.content.iter_any():
chunk = f'{chunk.decode("utf8")}\n\n'
yield chunk
if __name__ == '__main__':
pass

View file

@ -21,12 +21,14 @@ class Proxy:
password: str=None
):
self.proxy_type = proxy_type
self.ip_address = socket.gethostbyname(host)
self.host = host
self.ip_address = host
self.host = socket.gethostbyname(host)
self.port = port
self.username = username
self.password = password
self.url = f'socks5://{self.username}:{self.password}@{self.ip_address}:{self.port}'
async def initialize_connector(self, connector):
async with aiohttp.ClientSession(
connector=connector,
@ -61,13 +63,8 @@ class Proxy:
await self.initialize_connector(connector)
# Logging to check the connector state
print("Connector: Is closed?", connector.closed)
print("Connector: Is connected?", connector._connected)
return connector
default_proxy = Proxy(
proxy_type=os.getenv('PROXY_TYPE', 'http'),
host=os.getenv('PROXY_HOST', '127.0.0.1'),
@ -75,3 +72,18 @@ default_proxy = Proxy(
username=os.getenv('PROXY_USER'),
password=os.getenv('PROXY_PASS')
)
if __name__ == '__main__':
import requests
print(default_proxy.url)
received_ip = requests.get(
'https://checkip.amazonaws.com',
timeout=5,
proxies={
'https': default_proxy.url
}
).text.strip()
print(received_ip)

28
api/request_manager.py Normal file
View file

@ -0,0 +1,28 @@
import os
from dotenv import load_dotenv
from typing import Union, Optional
load_dotenv()
EXCLUDED_HEADERS = [
'content-encoding',
'content-length',
'transfer-encoding',
'connection'
]
class Request:
def __init__(self,
url: str,
method: str='GET',
payload: Optional[Union[dict, list]]=None,
headers: dict={
'Content-Type': 'application/json'
}
):
self.method = method.upper()
self.url = url.replace('/v1/v1', '/v1')
self.payload = payload
self.headers = headers
self.timeout = int(os.getenv('TRANSFER_TIMEOUT', '120'))

View file

@ -2,17 +2,14 @@
import os
import json
import aiohttp
import logging
import starlette
import netclient
import proxies
import netclient
import request_manager
from dotenv import load_dotenv
from starlette.background import StreamingResponse
load_dotenv()
# log to "api.log" file
@ -24,19 +21,12 @@ logging.basicConfig(
logging.info('API started')
EXCLUDED_HEADERS = [
'content-encoding',
'content-length',
'transfer-encoding',
'connection'
]
async def handle_api_request(incoming_request, target_endpoint: str=''):
"""Transfer a streaming response from the incoming request to the target endpoint"""
if not target_endpoint:
target_endpoint = os.getenv('CLOSEDAI_ENDPOINT')
target_url = f'{target_endpoint}{incoming_request.url.path}'.replace('/v1/v1', '/v1')
target_url = f'{target_endpoint}{incoming_request.url.path}'
logging.info('TRANSFER %s -> %s', incoming_request.url.path, target_url)
if target_url.endswith('/v1'):
@ -55,6 +45,12 @@ async def handle_api_request(incoming_request, target_endpoint: str=''):
if 'temperature' in payload or 'functions' in payload:
target_provider = 'closed'
return StreamingResponse(
content=netclient.receive_target_stream()
request = request_manager.Request(
url=target_url,
payload=payload,
method=incoming_request.method,
)
return starlette.responses.StreamingResponse(
content=netclient.stream_closedai_request(request)
)