diff --git a/README.md b/README.md index 8157536..64ef158 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # ☄️ Nova API Server -Reverse proxy server for OpenAI's API. +Reverse proxy server for "Closed"AI's API. ## Install Assuming you have a new version of Python 3 and pip installed: @@ -28,7 +28,7 @@ pip install . ## `.env` configuration ### `ACTUAL_IPS` (optional) -This is a security measure to make sure a proxy, VPN, Tor or any other IP hiding service is used by the host when accessing OpenAI's API. +This is a security measure to make sure a proxy, VPN, Tor or any other IP hiding service is used by the host when accessing "Closed"AI's API. It is a space separated list of IP addresses that are allowed to access the API. You can also just add the *beginning* of an API address, like `12.123.` to allow all IPs starting with `12.123.`. diff --git a/api/main.py b/api/main.py index ed04c66..9774aa7 100644 --- a/api/main.py +++ b/api/main.py @@ -1,19 +1,16 @@ import os -import httpx import fastapi -from keys import Keys -from starlette.requests import Request from starlette.responses import StreamingResponse -from starlette.background import BackgroundTask +from starlette.requests import Request from fastapi.middleware.cors import CORSMiddleware from dotenv import load_dotenv import security +import transfer load_dotenv() -target_api_client = httpx.AsyncClient(base_url='https://api.openai.com/') app = fastapi.FastAPI() @@ -32,9 +29,6 @@ async def startup_event(): security.enable_proxy() security.ip_protection_check() - # Setup key cache - Keys() - @app.get('/') async def root(): """Returns the root endpoint.""" @@ -46,34 +40,12 @@ async def root(): } async def _reverse_proxy(request: Request): - target_url = f'https://api.openai.com/v1/{request.url.path}' - key = Keys.get(request.body()['model']) - if not key: - return fastapi.responses.JSONResponse( - status_code=400, - content={ - 'error': 'No API Key for model given, please try again with a valid model.' - } - ) - request_to_api = target_api_client.build_request( - method=request.method, - url=target_url, - headers={ - 'Authorization': 'Bearer ' + key, - 'Content-Type': 'application/json' - }, - content=await request.body(), - ) + headers = { + name: value + for name, value in target_response.headers.items() + if name.lower() not in EXCLUDED_HEADERS + } - api_response = await target_api_client.send(request_to_api, stream=True) - - print(f'[{request.method}] {request.url.path} {api_response.status_code}') - Keys(key).unlock() - return StreamingResponse( - api_response.aiter_raw(), - status_code=api_response.status_code, - headers=api_response.headers, - background=BackgroundTask(api_response.aclose) - ) + # ... app.add_route('/{path:path}', _reverse_proxy, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']) diff --git a/api/proxies.py b/api/proxies.py new file mode 100644 index 0000000..56c2768 --- /dev/null +++ b/api/proxies.py @@ -0,0 +1,119 @@ +"""This module contains the Proxy class, which represents a proxy.""" + +import os +import httpx +import socket +import asyncio + +from sockslib import socks +from rich import print +from dotenv import load_dotenv + +load_dotenv() + +class Proxy: + """Represents a proxy. The type can be either http, https, socks4 or socks5.""" + + def __init__(self, + proxy_type: str='http', + host: str='127.0.0.1', + port: int=8080, + username: str=None, + password: str=None + ): + self.proxy_type = proxy_type + self.ip_address = socket.gethostbyname(host) + self.host = host + self.port = port + self.username = username + self.password = password + + @property + def auth(self): + """Returns the authentication part of the proxy URL, if the proxy has a username and password.""" + return f'{self.username}:{self.password}@' if all([self.username, self.password]) else '' + + @property + def protocol(self): + """Makes sure the hostnames are resolved correctly using the proxy. + See https://stackoverflow.com/a/43266186 + """ + return self.proxy_type# + 'h' if self.proxy_type.startswith('socks') else self.proxy_type + + def proxies(self): + """Returns a dictionary of proxies, ready to be used with the requests library or httpx. + """ + + url = f'{self.protocol}://{self.auth}{self.host}:{self.port}' + + proxies_dict = { + 'http://': url.replace('https', 'http') if self.proxy_type == 'https' else url, + 'https://': url.replace('http', 'https') if self.proxy_type == 'http' else url + } + + return proxies_dict + + def __str__(self): + return f'{self.proxy_type}://{len(self.auth) * "*"}{self.host}:{self.port}' + + def __repr__(self): + return f'' + +active_proxy = Proxy( + proxy_type=os.getenv('PROXY_TYPE', 'http'), + host=os.getenv('PROXY_HOST', '127.0.0.1'), + port=int(os.getenv('PROXY_PORT', 8080)), + username=os.getenv('PROXY_USER'), + password=os.getenv('PROXY_PASS') +) + +def activate_proxy() -> None: + socks.set_default_proxy( + proxy_type=socks.PROXY_TYPES[active_proxy.proxy_type.upper()], + addr=active_proxy.host, + port=active_proxy.port, + username=active_proxy.username, + password=active_proxy.password + ) + socket.socket = socks.socksocket + +def check_proxy(): + """Checks if the proxy is working.""" + + resp = httpx.get( + url='https://echo.hoppscotch.io/', + timeout=20, + proxies=active_proxy.proxies() + ) + resp.raise_for_status() + + return resp.ok + +async def check_api(): + model = 'gpt-3.5-turbo' + messages = [ + { + 'role': 'user', + 'content': 'Explain what a wormhole is.' + }, + ] + + headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer ' + os.getenv('CLOSEDAI_KEY') + } + + json_data = { + 'model': model, + 'messages': messages, + 'stream': True + } + + async with httpx.AsyncClient(timeout=20) as client: + async with client.stream("POST", 'https://api.openai.com/v1/chat/completions', headers=headers, json=json_data) as response: + response.raise_for_status() + async for chunk in response.aiter_text(): + print(chunk.strip()) + +if __name__ == '__main__': + asyncio.run(check_api()) diff --git a/api/security.py b/api/security.py index 85c8e52..8ad8207 100644 --- a/api/security.py +++ b/api/security.py @@ -1,40 +1,25 @@ import os -import socket import httpx -from sockslib import socks from rich import print +import proxies + +from dotenv import load_dotenv + +load_dotenv() is_proxy_enabled = False def enable_proxy(): - """Enables the SOCKS5 proxy.""" + """Enables the proxy.""" global is_proxy_enabled - if all([os.getenv('PROXY_HOST'), os.getenv('PROXY_PORT')]): - proxy_type = socks.PROXY_TYPE_HTTP + proxies.activate_proxy() - if '4' in os.getenv('PROXY_TYPE'): - proxy_type = socks.PROXY_TYPE_SOCKS4 + print(f'[green]SUCCESS: Proxy enabled: {proxies.active_proxy}[/green]') - if '5' in os.getenv('PROXY_TYPE'): - proxy_type = socks.PROXY_TYPE_SOCKS5 - - socks.set_default_proxy( - proxy_type=proxy_type, - addr=os.getenv('PROXY_HOST'), - port=int(os.getenv('PROXY_PORT')), - username=os.getenv('PROXY_USER'), - password=os.getenv('PROXY_PASS') - ) - socket.socket = socks.socksocket - - is_proxy_enabled = True - - else: - print('[yellow]WARNING: PROXY_PORT, PROXY_IP, PROXY_USER, and PROXY_PASS are not set in the .env file or empty. \ -Consider configuring a SOCKS5 proxy to improve the security.[/yellow]') + is_proxy_enabled = True class InsecureIPError(Exception): """Raised when the IP address of the server is not secure.""" @@ -45,17 +30,27 @@ def ip_protection_check(): actual_ips = os.getenv('ACTUAL_IPS', '').split() if actual_ips: - detected_ip = httpx.get('https://checkip.amazonaws.com', timeout=5).text.strip() + echo_response = httpx.get( + url='https://echo.hoppscotch.io/', + timeout=15 + ) + + response_data = echo_response.json() + response_ip = response_data['headers']['x-forwarded-for'] for actual_ip in actual_ips: - if detected_ip.startswith(actual_ip): - raise InsecureIPError(f'IP {detected_ip} is in the values of ACTUAL_IPS of the\ + if actual_ip in response_data: + raise InsecureIPError(f'IP pattern "{actual_ip}" is in the values of ACTUAL_IPS of the\ .env file. Enable a VPN or proxy to continue.') if is_proxy_enabled: - print(f'[green]SUCCESS: The IP {detected_ip} was detected, which seems to be a proxy. Great![/green]') + print(f'[green]SUCCESS: The IP "{response_ip}" was detected, which seems to be a proxy. Great![/green]') else: print('[yellow]WARNING: ACTUAL_IPS is not set in the .env file or empty.\ This means that the real IP of the server could be exposed. If you\'re using something\ like Cloudflare or Repl.it, you can ignore this warning.[/yellow]') + +if __name__ == '__main__': + enable_proxy() + ip_protection_check() diff --git a/api/sockslib/socks.py b/api/sockslib/socks.py index e2d5e5a..19b2b91 100644 --- a/api/sockslib/socks.py +++ b/api/sockslib/socks.py @@ -1,8 +1,8 @@ """ -SocksiPy - Python SOCKS module. -Version 1.5.1 +THE FOLLOWING CODE WAS TAKEN FROM https://raw.githubusercontent.com/m0rtem/CloudFail/master/socks.py -Code from https://github.com/XX-net/XX-Net/blob/master/code/default/lib/noarch/socks.py +SocksiPy - Python SOCKS module. +Version 1.5.7 Copyright 2006 Dan-Haim. All rights reserved. @@ -54,44 +54,28 @@ Modifications made by Anorov (https://github.com/Anorov) -Various small bug fixes """ -__version__ = "1.5.1" +__version__ = "1.5.7" -import os, sys -from base64 import b64encode import socket import struct from errno import EOPNOTSUPP, EINVAL, EAGAIN -from io import BytesIO, SEEK_CUR +from io import BytesIO +from os import SEEK_CUR +from base64 import b64encode try: - from collections import Callable -except: from collections.abc import Callable - -from six import string_types - -current_path = os.path.dirname(os.path.abspath(__file__)) -python_path = os.path.abspath( os.path.join(current_path, os.pardir, os.pardir)) -if sys.platform == "win32": - win32_lib = os.path.abspath( os.path.join(python_path, 'lib', 'win32')) - sys.path.append(win32_lib) - import win_inet_pton - inet_pton = win_inet_pton.inet_pton - inet_ntop = win_inet_pton.inet_ntop -else: - inet_pton = socket.inet_pton - inet_ntop = socket.inet_ntop - -from . import utils +except ImportError: + from collections import Callable PROXY_TYPE_SOCKS4 = SOCKS4 = 1 PROXY_TYPE_SOCKS5 = SOCKS5 = 2 PROXY_TYPE_HTTP = HTTP = 3 -PRINTABLE_PROXY_TYPES = {SOCKS4: "SOCKS4", SOCKS5: "SOCKS5", HTTP: "HTTP"} +PROXY_TYPES = {"SOCKS4": SOCKS4, "SOCKS5": SOCKS5, "HTTP": HTTP} +PRINTABLE_PROXY_TYPES = dict(zip(PROXY_TYPES.values(), PROXY_TYPES.keys())) _orgsocket = _orig_socket = socket.socket - class ProxyError(IOError): """ socket_err contains original socket.error exception. @@ -106,11 +90,6 @@ class ProxyError(IOError): def __str__(self): return self.msg - def __repr__(self): - # for %r - return repr(self.msg) - - class GeneralProxyError(ProxyError): pass class ProxyConnectionError(ProxyError): pass class SOCKS5AuthError(ProxyError): pass @@ -118,7 +97,6 @@ class SOCKS5Error(ProxyError): pass class SOCKS4Error(ProxyError): pass class HTTPError(ProxyError): pass - SOCKS4_ERRORS = { 0x5B: "Request rejected or failed", 0x5C: "Request rejected because SOCKS server cannot connect to identd on the client", 0x5D: "Request rejected because the client program and identd report different user-ids" @@ -139,7 +117,6 @@ DEFAULT_PORTS = { SOCKS4: 1080, HTTP: 8080 } - def set_default_proxy(proxy_type=None, addr=None, port=None, rdns=True, username=None, password=None): """ set_default_proxy(proxy_type, addr[, port[, rdns[, username, password]]]) @@ -147,44 +124,20 @@ def set_default_proxy(proxy_type=None, addr=None, port=None, rdns=True, username Sets a default proxy which all further socksocket objects will use, unless explicitly changed. All parameters are as for socket.set_proxy(). """ - proxy_type = utils.bytes2str_only(proxy_type) - addr = utils.to_str(addr) - if isinstance(port, bytes): - port = int(utils.to_str(port)) - else: - port = int(port) - username = utils.to_bytes(username) - password = utils.to_bytes(password) - - if isinstance(proxy_type, str): - proxy_type = proxy_type.lower() - if "http" in proxy_type: - proxy_type = PROXY_TYPE_HTTP - elif "socks5" in proxy_type: - proxy_type = PROXY_TYPE_SOCKS5 - elif "socks4" in proxy_type: - proxy_type = PROXY_TYPE_SOCKS4 - else: - raise ProxyError("unknown proxy type:%s" % proxy_type) - socksocket.default_proxy = (proxy_type, addr, port, rdns, - username if username else None, - password if password else None) - + username.encode() if username else None, + password.encode() if password else None) setdefaultproxy = set_default_proxy - def get_default_proxy(): """ Returns the default proxy, set by set_default_proxy. """ return socksocket.default_proxy - getdefaultproxy = get_default_proxy - def wrap_module(module): """ Attempts to replace a module's socket library with a SOCKS socket. Must set @@ -199,27 +152,64 @@ def wrap_module(module): wrapmodule = wrap_module - def create_connection(dest_pair, proxy_type=None, proxy_addr=None, - proxy_port=None, proxy_username=None, - proxy_password=None, timeout=None): + proxy_port=None, proxy_rdns=True, + proxy_username=None, proxy_password=None, + timeout=None, source_address=None, + socket_options=None): """create_connection(dest_pair, *[, timeout], **proxy_args) -> socket object Like socket.create_connection(), but connects to proxy before returning the socket object. dest_pair - 2-tuple of (IP/hostname, port). - **proxy_args - Same args passed to socksocket.set_proxy(). + **proxy_args - Same args passed to socksocket.set_proxy() if present. timeout - Optional socket timeout value, in seconds. + source_address - tuple (host, port) for the socket to bind to as its source + address before connecting (only for compatibility) """ - sock = socksocket() - if isinstance(timeout, (int, float)): - sock.settimeout(timeout) - sock.set_proxy(proxy_type, proxy_addr, proxy_port, - proxy_username, proxy_password) - sock.connect(dest_pair) - return sock + # Remove IPv6 brackets on the remote address and proxy address. + remote_host, remote_port = dest_pair + if remote_host.startswith('['): + remote_host = remote_host.strip('[]') + if proxy_addr and proxy_addr.startswith('['): + proxy_addr = proxy_addr.strip('[]') + err = None + + # Allow the SOCKS proxy to be on IPv4 or IPv6 addresses. + for r in socket.getaddrinfo(proxy_addr, proxy_port, 0, socket.SOCK_STREAM): + family, socket_type, proto, canonname, sa = r + sock = None + try: + sock = socksocket(family, socket_type, proto) + + if socket_options is not None: + for opt in socket_options: + sock.setsockopt(*opt) + + if isinstance(timeout, (int, float)): + sock.settimeout(timeout) + + if proxy_type is not None: + sock.set_proxy(proxy_type, proxy_addr, proxy_port, proxy_rdns, + proxy_username, proxy_password) + if source_address is not None: + sock.bind(source_address) + + sock.connect((remote_host, remote_port)) + return sock + + except socket.error as e: + err = e + if sock is not None: + sock.close() + sock = None + + if err is not None: + raise err + + raise socket.error("gai returned empty list.") class _BaseSocket(socket.socket): """Allows Python 2's "delegated" methods such as send() to be overridden @@ -234,11 +224,8 @@ class _BaseSocket(socket.socket): _savenames = list() - def _makemethod(name): return lambda self, *pos, **kw: self._savedmethods[name](*pos, **kw) - - for name in ("sendto", "send", "recvfrom", "recv"): method = getattr(_BaseSocket, name, None) @@ -250,7 +237,6 @@ for name in ("sendto", "send", "recvfrom", "recv"): _BaseSocket._savenames.append(name) setattr(_BaseSocket, name, _makemethod(name)) - class socksocket(_BaseSocket): """socksocket([family[, type[, proto]]]) -> socket object @@ -262,27 +248,18 @@ class socksocket(_BaseSocket): default_proxy = None - def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, _sock=None): - if type not in {socket.SOCK_STREAM, socket.SOCK_DGRAM}: + def __init__(self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, *args, **kwargs): + if type not in (socket.SOCK_STREAM, socket.SOCK_DGRAM): msg = "Socket type must be stream or datagram, not {!r}" raise ValueError(msg.format(type)) + _BaseSocket.__init__(self, family, type, proto, *args, **kwargs) self._proxyconn = None # TCP connection to keep UDP relay alive - self.resolve_dest = True if self.default_proxy: self.proxy = self.default_proxy - proxy_host = self.proxy[1] - if utils.check_ip_valid6(proxy_host): - family=socket.AF_INET6 - elif utils.check_ip_valid4(proxy_host): - family=socket.AF_INET - else: self.proxy = (None, None, None, None, None, None) - - _BaseSocket.__init__(self, family, type, proto, _sock) - self.proxy_sockname = None self.proxy_peername = None @@ -317,34 +294,9 @@ class socksocket(_BaseSocket): password - Password to authenticate with to the server. Only relevant when username is also provided. """ - - proxy_type = utils.bytes2str_only(proxy_type) - addr = utils.to_str(addr) - if isinstance(port, bytes): - port = int(utils.to_str(port)) - else: - port = int(port) - username = utils.to_bytes(username) - password = utils.to_bytes(password) - - if isinstance(proxy_type, string_types): - proxy_type = proxy_type.lower() - if "http" in proxy_type: - proxy_type = PROXY_TYPE_HTTP - self.resolve_dest = False - elif "socks5" in proxy_type: - if proxy_type == "socks5h": - self.resolve_dest = False - rdns = True - proxy_type = PROXY_TYPE_SOCKS5 - elif "socks4" in proxy_type: - proxy_type = PROXY_TYPE_SOCKS4 - else: - raise ProxyError("unknown proxy type:%s" % proxy_type) - self.proxy = (proxy_type, addr, port, rdns, - username if username else None, - password if password else None) + username.encode() if username else None, + password.encode() if password else None) setproxy = set_proxy @@ -559,26 +511,39 @@ class socksocket(_BaseSocket): and the resolved address as a tuple object. """ host, port = addr - host = utils.to_str(host) proxy_type, _, _, rdns, username, password = self.proxy + family_to_byte = {socket.AF_INET: b"\x01", socket.AF_INET6: b"\x04"} - if utils.check_ip_valid6(host): - addr_bytes = inet_pton(socket.AF_INET6, host) - file.write(b"\x04" + addr_bytes) - elif utils.check_ip_valid4(host): - addr_bytes = socket.inet_aton(host) - file.write(b"\x01" + addr_bytes) + # If the given destination address is an IP address, we'll + # use the IP address request even if remote resolving was specified. + # Detect whether the address is IPv4/6 directly. + for family in (socket.AF_INET, socket.AF_INET6): + try: + addr_bytes = socket.inet_pton(family, host) + file.write(family_to_byte[family] + addr_bytes) + host = socket.inet_ntop(family, addr_bytes) + file.write(struct.pack(">H", port)) + return host, port + except socket.error: + continue + + # Well it's not an IP number, so it's probably a DNS name. + if rdns: + # Resolve remotely + host_bytes = host.encode('idna') + file.write(b"\x03" + chr(len(host_bytes)).encode() + host_bytes) else: - if rdns: - # Resolve remotely - host_bytes = host.encode("utf-8") - file.write(b"\x03" + chr(len(host_bytes)).encode() + host_bytes) - else: - # Resolve locally - addr_bytes = socket.inet_aton(socket.gethostbyname(host)) - file.write(b"\x01" + addr_bytes) - host = socket.inet_ntoa(addr_bytes) + # Resolve locally + addresses = socket.getaddrinfo(host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, socket.IPPROTO_TCP, socket.AI_ADDRCONFIG) + # We can't really work out what IP is reachable, so just pick the + # first. + target_addr = addresses[0] + family = target_addr[0] + host = target_addr[4][0] + addr_bytes = socket.inet_pton(family, host) + file.write(family_to_byte[family] + addr_bytes) + host = socket.inet_ntop(family, addr_bytes) file.write(struct.pack(">H", port)) return host, port @@ -590,7 +555,7 @@ class socksocket(_BaseSocket): length = self._readall(file, 1) addr = self._readall(file, ord(length)) elif atyp == b"\x04": - addr = inet_ntop(socket.AF_INET6, self._readall(file, 16)) + addr = socket.inet_ntop(socket.AF_INET6, self._readall(file, 16)) else: raise GeneralProxyError("SOCKS5 proxy server sent invalid data") @@ -602,7 +567,6 @@ class socksocket(_BaseSocket): Negotiates a connection through a SOCKS4 server. """ proxy_type, addr, port, rdns, username, password = self.proxy - dest_addr = utils.to_str(dest_addr) writer = self.makefile("wb") reader = self.makefile("rb", 0) # buffering=0 renamed in Python 3 @@ -611,7 +575,7 @@ class socksocket(_BaseSocket): remote_resolve = False try: addr_bytes = socket.inet_aton(dest_addr) - except socket.error as e: + except socket.error: # It's a DNS name. Check where it should be resolved. if rdns: addr_bytes = b"\x00\x00\x00\x01" @@ -632,7 +596,7 @@ class socksocket(_BaseSocket): # NOTE: This is actually an extension to the SOCKS4 protocol # called SOCKS4A and may not be supported in all cases. if remote_resolve: - writer.write(dest_addr.encode('utf-8') + b"\x00") + writer.write(dest_addr.encode('idna') + b"\x00") writer.flush() # Get the response from the server @@ -657,37 +621,30 @@ class socksocket(_BaseSocket): reader.close() writer.close() - def _negotiate_HTTP(self, dest_host, dest_port): + def _negotiate_HTTP(self, dest_addr, dest_port): """ Negotiates a connection through an HTTP server. NOTE: This currently only supports HTTP CONNECT-style proxies. """ - proxy_type, proxy_addr, port, rdns, username, password = self.proxy + proxy_type, addr, port, rdns, username, password = self.proxy # If we need to resolve locally, we do this now - dest_host = utils.to_bytes(dest_host) - if b":" not in dest_host and not rdns: - dest_addr = socket.gethostbyname(dest_host) - dest_addr = utils.to_bytes(dest_addr) - else: - dest_addr = dest_host + addr = dest_addr if rdns else socket.gethostbyname(dest_addr) http_headers = [ - (b"CONNECT " + utils.to_bytes(dest_addr) + b":" - + str(dest_port).encode() + b" HTTP/1.1"), - b"Host: " + dest_addr + b"CONNECT " + addr.encode('idna') + b":" + str(dest_port).encode() + b" HTTP/1.1", + b"Host: " + dest_addr.encode('idna') ] if username and password: - http_headers.append(b"Proxy-Authorization: basic " - + b64encode(username + b":" + password)) + http_headers.append(b"Proxy-Authorization: basic " + b64encode(username + b":" + password)) http_headers.append(b"\r\n") self.sendall(b"\r\n".join(http_headers)) # We just need the first line to check if the connection was successful - fobj = self.makefile("rb") + fobj = self.makefile() status_line = fobj.readline() fobj.close() @@ -695,11 +652,11 @@ class socksocket(_BaseSocket): raise GeneralProxyError("Connection closed unexpectedly") try: - proto, status_code, status_msg = status_line.split(b" ", 2) + proto, status_code, status_msg = status_line.split(" ", 2) except ValueError: raise GeneralProxyError("HTTP proxy server sent invalid response") - if not proto.startswith(b"HTTP/"): + if not proto.startswith("HTTP/"): raise GeneralProxyError("Proxy server does not appear to be an HTTP proxy") try: @@ -716,7 +673,7 @@ class socksocket(_BaseSocket): raise HTTPError(error) self.proxy_sockname = (b"0.0.0.0", 0) - self.proxy_peername = dest_addr, dest_port + self.proxy_peername = addr, dest_port _proxy_negotiators = { SOCKS4: _negotiate_SOCKS4, @@ -724,28 +681,27 @@ class socksocket(_BaseSocket): HTTP: _negotiate_HTTP } + def connect(self, dest_pair): """ Connects to the specified destination through a proxy. Uses the same API as socket's connect(). To select the proxy server, use set_proxy(). - dest_pair + dest_pair - 2-tuple of (IP/hostname, port). """ - if len(dest_pair) == 2: - # IPv4 - dest_addr, dest_port = dest_pair - elif len(dest_pair) == 4: - # IPv6 - dest_addr, dest_port, st_zero, st_stream = dest_pair - else: - raise GeneralProxyError("Invalid destination-connection (host, port) pair") + if len(dest_pair) != 2 or dest_pair[0].startswith("["): + # Probably IPv6, not supported -- raise an error, and hope + # Happy Eyeballs (RFC6555) makes sure at least the IPv4 + # connection works... + raise socket.error("PySocks doesn't support IPv6") + + dest_addr, dest_port = dest_pair if self.type == socket.SOCK_DGRAM: if not self._proxyconn: self.bind(("", 0)) - if self.resolve_dest: - dest_addr = socket.gethostbyname(dest_addr) + dest_addr = socket.gethostbyname(dest_addr) # If the host address is INADDR_ANY or similar, reset the peer # address so that packets are received from any peer @@ -755,30 +711,33 @@ class socksocket(_BaseSocket): self.proxy_peername = (dest_addr, dest_port) return - proxy_type, proxy_host, proxy_port, rdns, username, password = self.proxy - proxy_host = utils.to_bytes(proxy_host) + proxy_type, proxy_addr, proxy_port, rdns, username, password = self.proxy # Do a minimal input check first - if not dest_addr or not isinstance(dest_port, int): + if (not isinstance(dest_pair, (list, tuple)) + or len(dest_pair) != 2 + or not dest_addr + or not isinstance(dest_port, int)): raise GeneralProxyError("Invalid destination-connection (host, port) pair") + if proxy_type is None: # Treat like regular socket object + self.proxy_peername = dest_pair _BaseSocket.connect(self, (dest_addr, dest_port)) return - proxy_port = proxy_port or DEFAULT_PORTS.get(proxy_type) - if not proxy_port: - raise GeneralProxyError("Invalid proxy port") + proxy_addr = self._proxy_addr() try: # Initial connection to proxy server - proxy_ip = socket.gethostbyname(proxy_host) - _BaseSocket.connect(self, (proxy_ip, proxy_port)) + _BaseSocket.connect(self, proxy_addr) + except socket.error as error: # Error while connecting to proxy self.close() - proxy_server = "{0}:{1}".format(proxy_host, proxy_port) + proxy_addr, proxy_port = proxy_addr + proxy_server = "{0}:{1}".format(proxy_addr, proxy_port) printable_type = PRINTABLE_PROXY_TYPES[proxy_type] msg = "Error connecting to {0} proxy {1}".format(printable_type, @@ -800,8 +759,12 @@ class socksocket(_BaseSocket): self.close() raise - -if __name__ == "__main__": - name = "abc" - name2 = name.encode('idna') - print(name2) \ No newline at end of file + def _proxy_addr(self): + """ + Return proxy address to connect to as tuple object + """ + proxy_type, proxy_addr, proxy_port, rdns, username, password = self.proxy + proxy_port = proxy_port or DEFAULT_PORTS.get(proxy_type) + if not proxy_port: + raise GeneralProxyError("Invalid proxy type") + return proxy_addr, proxy_port diff --git a/api/sockslib/utils.py b/api/sockslib/utils.py deleted file mode 100644 index c843f9e..0000000 --- a/api/sockslib/utils.py +++ /dev/null @@ -1,385 +0,0 @@ -"""" -Code from https://github.com/XX-net/XX-Net/blob/master/code/default/lib/noarch/utils.py - -Copyright (c) [2022], [XX-Net] -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" - -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -import re -import os -import threading -from functools import reduce -from six import string_types - -ipv4_pattern = re.compile(br'^(\d{1,3})\.(\d{1,3})\.(\d{1,3})\.(\d{1,3})$') - -ipv6_pattern = re.compile(br""" - ^ - \s* # Leading whitespace - (?!.*::.*::) # Only a single whildcard allowed - (?:(?!:)|:(?=:)) # Colon iff it would be part of a wildcard - (?: # Repeat 6 times: - [0-9a-f]{0,4} # A group of at most four hexadecimal digits - (?:(?<=::)|(? 255: - return 0 - return 1 - else: - return 0 - - -def check_ip_valid6(ip): - """Copied from http://stackoverflow.com/a/319293/2755602""" - ip = to_bytes(ip) - - return ipv6_pattern.match(ip) is not None - - -def check_ip_valid(ip): - ip = to_bytes(ip) - if b'.' in ip: - return check_ip_valid4(ip) - else: - return check_ip_valid6(ip) - - -def get_ip_port(ip_str, port=443): - ip_str = to_bytes(ip_str) - if b"." in ip_str: - # ipv4 - if b":" in ip_str: - # format is ip:port - ps = ip_str.split(b":") - ip = ps[0] - port = ps[1] - else: - # format is ip - ip = ip_str - else: - # ipv6 - if b"[" in ip_str: - # format: [ab01:12:23:34::1] - # format: [ab01:12:23:34::1]:23 - - p1 = ip_str.find(b"[") - p2 = ip_str.find(b"]") - ip = ip_str[p1 + 1:p2] - port_str = ip_str[p2 + 1:] - if len(port_str) > 0: - port = port_str[1:] - else: - ip = ip_str - - return ip, int(port) - - -domain_allowed = re.compile("(?!-)[A-Z\d-]{1,63}(? 255: - return False - if hostname.endswith("."): - hostname = hostname[:-1] - - return all(domain_allowed.match(x) for x in hostname.split(".")) - - -def str2hex(data): - data = to_str(data) - return ":".join("{:02x}".format(ord(c)) for c in data) - - -def get_ip_maskc(ip_str): - head = ".".join(ip_str.split(".")[:-1]) - return head + ".0" - - -def split_ip(strline): - """从每组地址中分离出起始IP以及结束IP""" - begin = "" - end = "" - if "-" in strline: - num_regions = strline.split(".") - if len(num_regions) == 4: - "xxx.xxx.xxx-xxx.xxx-xxx" - begin = '' - end = '' - for region in num_regions: - if '-' in region: - s, e = region.split('-') - begin += '.' + s - end += '.' + e - else: - begin += '.' + region - end += '.' + region - begin = begin[1:] - end = end[1:] - - else: - "xxx.xxx.xxx.xxx-xxx.xxx.xxx.xxx" - begin, end = strline.split("-") - if 1 <= len(end) <= 3: - prefix = begin[0:begin.rfind(".")] - end = prefix + "." + end - - elif strline.endswith("."): - "xxx.xxx.xxx." - begin = strline + "0" - end = strline + "255" - elif "/" in strline: - "xxx.xxx.xxx.xxx/xx" - (ip, bits) = strline.split("/") - if check_ip_valid4(ip) and (0 <= int(bits) <= 32): - orgip = ip_string_to_num(ip) - end_bits = (1 << (32 - int(bits))) - 1 - begin_bits = 0xFFFFFFFF ^ end_bits - begin = ip_num_to_string(orgip & begin_bits) - end = ip_num_to_string(orgip | end_bits) - else: - "xxx.xxx.xxx.xxx" - begin = strline - end = strline - - return begin, end - - -def generate_random_lowercase(n): - min_lc = ord(b'a') - len_lc = 26 - ba = bytearray(os.urandom(n)) - for i, b in enumerate(ba): - ba[i] = min_lc + b % len_lc # convert 0..255 to 97..122 - # sys.stdout.buffer.write(ba) - return ba - - -class SimpleCondition(object): - def __init__(self): - self.lock = threading.Condition() - - def notify(self): - self.lock.acquire() - self.lock.notify() - self.lock.release() - - def wait(self): - self.lock.acquire() - self.lock.wait() - self.lock.release() - - -def split_domain(host): - host = to_bytes(host) - hl = host.split(b".") - return hl[0], b".".join(hl[1:]) - - -def ip_string_to_num(s): - """Convert dotted IPv4 address to integer.""" - return reduce(lambda a, b: a << 8 | b, list(map(int, s.split(".")))) - - -def ip_num_to_string(ip): - """Convert 32-bit integer to dotted IPv4 address.""" - return ".".join([str(ip >> n & 0xFF) for n in [24, 16, 8, 0]]) - - -private_ipv4_range = [ - ("10.0.0.0", "10.255.255.255"), - ("127.0.0.0", "127.255.255.255"), - ("169.254.0.0", "169.254.255.255"), - ("172.16.0.0", "172.31.255.255"), - ("192.168.0.0", "192.168.255.255") -] - -private_ipv6_range = [ - ("::1", "::1"), - ("fc00::", "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") -] - -private_ipv4_range_bin = [] -for b, e in private_ipv4_range: - bb = ip_string_to_num(b) - ee = ip_string_to_num(e) - private_ipv4_range_bin.append((bb, ee)) - - -def is_private_ip(ip): - ip = to_str(ip) - try: - if "." in ip: - ip_bin = ip_string_to_num(ip) - for b, e in private_ipv4_range_bin: - if b <= ip_bin <= e: - return True - return False - else: - if ip == "::1": - return True - - fi = ip.find(":") - if fi != 4: - return False - - be = ip[0:2] - if be in ["fc", "fd"]: - return True - else: - return False - except Exception as e: - # print(("is_private_ip(%s), except:%r", ip, e)) - return False - - -import string - -printable = set(string.printable) - - -def get_printable(s): - return [x for x in s if x in printable] - - -def compare_version(version, reference_version): - try: - p = re.compile(r'([0-9]+)\.([0-9]+)\.([0-9]+)') - m1 = p.match(version) - m2 = p.match(reference_version) - v1 = list(map(int, list(map(m1.group, [1, 2, 3])))) - v2 = list(map(int, list(map(m2.group, [1, 2, 3])))) - - if v1 > v2: - return 1 - elif v1 < v2: - return -1 - else: - return 0 - except Exception as e: - print("older_or_equal fail: %s, %s" % (version, reference_version)) - raise e - - -def map_with_parameter(function, datas, args): - l = [] - for data in datas: - d_out = function(data, args) - l.append(d_out) - return l - - -def to_bytes(data, coding='utf-8'): - if isinstance(data, bytes): - return data - if isinstance(data, string_types): - return data.encode(coding) - if isinstance(data, dict): - return dict(map_with_parameter(to_bytes, data.items(), coding)) - if isinstance(data, tuple): - return tuple(map_with_parameter(to_bytes, data, coding)) - if isinstance(data, list): - return list(map_with_parameter(to_bytes, data, coding)) - if isinstance(data, int): - return to_bytes(str(data)) - if data is None: - return data - return bytes(data) - - -def to_str(data, coding='utf-8'): - if isinstance(data, string_types): - return data - if isinstance(data, bytes): - return data.decode(coding) - if isinstance(data, bytearray): - return data.decode(coding) - if isinstance(data, dict): - return dict(map_with_parameter(to_str, data.items(), coding)) - if isinstance(data, tuple): - return tuple(map_with_parameter(to_str, data, coding)) - if isinstance(data, list): - return list(map_with_parameter(to_str, data, coding)) - if isinstance(data, int): - return str(data) - if data is None: - return data - return str(data) - - -def bytes2str_only(data, coding='utf-8'): - if isinstance(data, bytes): - return data.decode(coding) - if isinstance(data, dict): - return dict(map_with_parameter(bytes2str_only, data.items(), coding)) - if isinstance(data, tuple): - return tuple(map_with_parameter(bytes2str_only, data, coding)) - if isinstance(data, list): - return list(map_with_parameter(bytes2str_only, data, coding)) - else: - return data - - -def merge_two_dict(x, y): - """Given two dictionaries, merge them into a new dict as a shallow copy.""" - z = x.copy() - z.update(y) - return z - - -if __name__ == '__main__': - # print(get_ip_port("1.2.3.4", 443)) - # print(get_ip_port("1.2.3.4:8443", 443)) - print((get_ip_port("[face:ab1:11::0]", 443))) - print((get_ip_port("ab01::1", 443))) - print((get_ip_port("[ab01:55::1]:8444", 443))) \ No newline at end of file diff --git a/api/transfer.py b/api/transfer.py new file mode 100644 index 0000000..376d24d --- /dev/null +++ b/api/transfer.py @@ -0,0 +1,30 @@ +import os +import httpx + +from dotenv import load_dotenv + +from starlette.responses import StreamingResponse +from starlette.background import BackgroundTask + +load_dotenv() + +EXCLUDED_HEADERS = [ + 'content-encoding', + 'content-length', + 'transfer-encoding', + 'connection' +] + +async def stream_api_response(request, target_endpoint: str='https://api.openai.com/v1'): + async with httpx.AsyncClient(timeout=120) as client: + async with client.stream( + method=request.method, + url=f'{target_endpoint}/{request.url.path}', + headers={ + 'Authorization': 'Bearer ' + os.getenv('CLOSEDAI_KEY'), + 'Content-Type': 'application/json' + }, + data=await request.body(), + ) as target_response: + target_response.raise_for_status() + diff --git a/requirements.txt b/requirements.txt index 137024f..6938018 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ fastapi -httpx +httpx[socks] openai python-dotenv rich starlette +win_inet_pton diff --git a/tests/__main__.py b/tests/__main__.py index ce0834d..e3ffe37 100644 --- a/tests/__main__.py +++ b/tests/__main__.py @@ -2,7 +2,7 @@ from typing import List -import openai +import openai as closedai import httpx PORT = 8000 @@ -43,12 +43,12 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict: return response.json()['choices'][0] def test_library(): - """Tests if the endpoint is working with the OpenAI library.""" + """Tests if the endpoint is working with the "Closed"AI library.""" - openai.api_base = ENDPOINT - openai.api_key = 'nv-LIBRARY-TEST' + closedai.api_base = ENDPOINT + closedai.api_key = 'nv-LIBRARY-TEST' - completion = openai.ChatCompletion.create( + completion = closedai.ChatCompletion.create( model=MODEL, messages=MESSAGES, ) diff --git a/tests/monkeypatch.py b/tests/monkeypatch.py deleted file mode 100644 index 4b32141..0000000 --- a/tests/monkeypatch.py +++ /dev/null @@ -1,31 +0,0 @@ -# Credit: @miss_articulate_python on Discord - -import configparser -import os -import pathlib -import openai - -# creating a config file, so we can store the api key and other settings -config_file = pathlib.Path(__file__).parent / 'config.ini' -config = configparser.ConfigParser() -config.read_dict({ - 'openai': { - 'api_base': 'http://ENDPOINT', - 'api_key': '', - 'reset_ip_every_request': 'false' - } -}) - -if config_file.exists(): - config.read(config_file) - -with open(config_file, 'w', encoding='utf8') as configfile: - config.write(configfile) - -# the normal patch that you apply -openai.api_base = config['openai']['api_base'] -openai.api_key = config['openai']['api_key'] - -# many modules lookup these environment variable, so we pre-emptively set them -os.environ['OPENAI_API_KEY'] = config['openai']['api_key'] -os.environ['OPENAI_API_BASE'] = config['openai']['api_base']