Works for BetterGPT

This commit is contained in:
nsde 2023-07-19 23:51:28 +02:00
parent 3cfdddab7c
commit 83a3c13abf
10 changed files with 109 additions and 248 deletions

17
Dockerfile Normal file
View file

@ -0,0 +1,17 @@
#
FROM python:3.10
#
WORKDIR /code
#
COPY ./requirements.txt /code/requirements.txt
#
RUN pip install . pip install --no-cache-dir --upgrade -r /code/requirements.txt
#
COPY ./app /code/app
#
CMD ["uvicorn", "api.main:app", "--host", "0.0.0.0", "--port", "2333"]

View file

@ -1,104 +0,0 @@
import json
import os
import random
import time
from pymongo import MongoClient
with open('./config.json', 'r') as file:
config = json.load(file)
class Keys:
# --- START OF CONFIG ---
MONGO_URI = os.getenv('MONGO_URI') or config.get('MONGO_URI')
# --- END OF CONFIG ---
locked_keys = set()
cache = {}
# Initialize MongoDB
client = MongoClient(MONGO_URI)
db = client.get_database('keys_db')
collection = db['keys']
def __init__(self, key: str, model: str, provider_name: str, ratelimit: int, url: str):
self.key = key
self.model = model
self.provider_name = provider_name
self.ratelimit = ratelimit
self.url = url
if not Keys.cache:
self._load_keys()
def _load_keys(self) -> None:
cursor = Keys.collection.find({}, {'_id': 0, 'key_value': 1, 'model': 1, 'provider_name': 1, 'ratelimit': 1, 'url': 1, 'last_used': 1})
for doc in cursor:
key_value = doc['key_value']
model = doc['model']
provider_name = doc['provider_name']
ratelimit = doc['ratelimit']
url = doc['url']
last_used = doc.get('last_used', 0)
key_data = {'provider_name': provider_name, 'ratelimit': ratelimit, 'url': url, 'last_used': last_used}
Keys.cache.setdefault(model, {}).setdefault(key_value, key_data)
def lock(self) -> None:
self.locked_keys.add(self.key)
def unlock(self) -> None:
self.locked_keys.remove(self.key)
def is_locked(self) -> bool:
return self.key in self.locked_keys
@staticmethod
def get(model: str) -> str:
key_candidates = list(Keys.cache.get(model, {}).keys())
random.shuffle(key_candidates)
current_time = time.time()
for key_candidate in key_candidates:
key_data = Keys.cache[model][key_candidate]
key = Keys(key_candidate, model, key_data['provider_name'], key_data['ratelimit'], key_data['url'])
time_since_last_used = current_time - key_data.get('last_used', 0)
if not key.is_locked() and time_since_last_used >= key.ratelimit:
key.lock()
key_data['last_used'] = current_time # Update last_used in the cache
Keys.collection.update_one(
{'key_value': key.key, 'model': key.model},
{'$set': {'last_used': current_time}} # Update last_used in the database
)
return {
'url': key.url,
'key_value': key.key
}
print(f"[WARN] No unlocked keys found for model '{model}' in get keys request!")
def delete(self) -> None:
Keys.collection.delete_one({'key_value': self.key, 'model': self.model})
# Update cache
try:
del Keys.cache[self.model][self.key]
except KeyError:
print(f"[WARN] Tried to remove a key from cache which was not present: {self.key}")
def save(self) -> None:
key_data = {
'provider_name': self.provider_name,
'ratelimit': self.ratelimit,
'url': self.url,
'last_used': 0 # Initialize last_used to 0 when saving a new key
}
Keys.collection.insert_one({'key_value': self.key, 'model': self.model, **key_data})
# Update cache
Keys.cache.setdefault(self.model, {}).setdefault(self.key, key_data)
# Usage example:
# os.environ['MONGO_URI'] = "mongodb://localhost:27017"
# key_instance = Keys("example_key", "gpt-4", "openai", "10", "https://whatever-openai-thing-is/chat/completions/")
# key_instance.save()
# key_value = Keys.get("gpt-4")

View file

@ -1,8 +1,8 @@
import os
import fastapi
"""FastAPI setup."""
import fastapi
import asyncio
from starlette.responses import StreamingResponse
from starlette.requests import Request
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
@ -25,9 +25,7 @@ app.add_middleware(
@app.on_event('startup')
async def startup_event():
"""Read up the API server."""
# security.enable_proxy()
# security.ip_protection_check()
# await security.ip_protection_check()
@app.get('/')
async def root():
@ -35,8 +33,15 @@ async def root():
return {
'status': 'ok',
'discord': 'https://discord.gg/85gdcd57Xr',
'github': 'https://github.com/Luna-OSS'
'readme': 'https://nova-oss.com'
}
@app.route('/v1')
async def api_root():
"""Returns the API root endpoint."""
return {
'status': 'ok',
}
app.add_route('/{path:path}', transfer.transfer_streaming_response, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])

View file

@ -1,12 +1,9 @@
"""This module contains the Proxy class, which represents a proxy."""
"""This module makes it easy to implement proxies by providing a class.."""
import os
import httpx
import socket
import asyncio
from sockslib import socks
from rich import print
from dotenv import load_dotenv
load_dotenv()
@ -28,43 +25,6 @@ class Proxy:
self.username = username
self.password = password
@property
def auth(self):
"""Returns the authentication part of the proxy URL, if the proxy has a username and password."""
return f'{self.username}:{self.password}@' if all([self.username, self.password]) else ''
@property
def protocol(self):
"""Makes sure the hostnames are resolved correctly using the proxy.
See https://stackoverflow.com/a/43266186
"""
return self.proxy_type# + 'h' if self.proxy_type.startswith('socks') else self.proxy_type
@property
def proxies(self):
"""Returns a dictionary of proxies, ready to be used with the requests library or httpx.
"""
url = f'{self.protocol}://{self.auth}{self.host}:{self.port}'
proxies_dict = {
'http://': url.replace('https', 'http') if self.proxy_type == 'https' else url,
'https://': url.replace('http', 'https') if self.proxy_type == 'http' else url
}
return proxies_dict
@property
def url(self):
"""Returns the proxy URL."""
return f'{self.protocol}://{self.auth}{self.host}:{self.port}'
def __str__(self):
return f'{self.proxy_type}://{len(self.auth) * "*"}{self.host}:{self.port}'
def __repr__(self):
return f'<Proxy type={self.proxy_type} host={self.host} port={self.port} username={self.username} password={len(self.password) * "*"}>'
active_proxy = Proxy(
proxy_type=os.getenv('PROXY_TYPE', 'http'),
host=os.getenv('PROXY_HOST', '127.0.0.1'),
@ -72,54 +32,3 @@ active_proxy = Proxy(
username=os.getenv('PROXY_USER'),
password=os.getenv('PROXY_PASS')
)
def activate_proxy() -> None:
socks.set_default_proxy(
proxy_type=socks.PROXY_TYPES[active_proxy.proxy_type.upper()],
addr=active_proxy.host,
port=active_proxy.port,
username=active_proxy.username,
password=active_proxy.password
)
socket.socket = socks.socksocket
def check_proxy():
"""Checks if the proxy is working."""
resp = httpx.get(
url='https://echo.hoppscotch.io/',
timeout=20,
proxies=active_proxy.proxies
)
resp.raise_for_status()
return resp.ok
async def check_api():
model = 'gpt-3.5-turbo'
messages = [
{
'role': 'user',
'content': 'Explain what a wormhole is.'
},
]
headers = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + os.getenv('CLOSEDAI_KEY')
}
json_data = {
'model': model,
'messages': messages,
'stream': True
}
async with httpx.AsyncClient(timeout=20) as client:
async with client.stream("POST", 'https://api.openai.com/v1/chat/completions', headers=headers, json=json_data) as response:
response.raise_for_status()
async for chunk in response.aiter_text():
print(chunk.strip())
if __name__ == '__main__':
asyncio.run(check_api())

View file

@ -1,47 +1,45 @@
"""Security checks for the API. Checks if the IP is masked etc."""
import os
import httpx
import asyncio
from rich import print
import proxies
from tests import check_proxy
from dotenv import load_dotenv
load_dotenv()
is_proxy_enabled = False
is_proxy_enabled = bool(os.getenv('PROXY_HOST', None))
class InsecureIPError(Exception):
"""Raised when the IP address of the server is not secure."""
def ip_protection_check():
async def ip_protection_check():
"""Makes sure that the actual server IP address is not exposed to the public."""
if not is_proxy_enabled:
print('[yellow]WARN: The proxy is not enabled. \
Skipping IP protection check.[/yellow]')
return True
actual_ips = os.getenv('ACTUAL_IPS', '').split()
if actual_ips:
echo_response = httpx.get(
url='https://echo.hoppscotch.io/',
timeout=15
)
response_data = echo_response.json()
response_ip = response_data['headers']['x-forwarded-for']
# run the async function check_proxy() and get its result
response_ip = await check_proxy()
for actual_ip in actual_ips:
if actual_ip in response_data:
if actual_ip in response_ip:
raise InsecureIPError(f'IP pattern "{actual_ip}" is in the values of ACTUAL_IPS of the\
.env file. Enable a VPN or proxy to continue.')
if is_proxy_enabled:
print(f'[green]SUCCESS: The IP "{response_ip}" was detected, which seems to be a proxy. Great![/green]')
print(f'[green]GOOD: The IP "{response_ip}" was detected, which seems to be a proxy. Great![/green]')
return True
else:
print('[yellow]WARNING: ACTUAL_IPS is not set in the .env file or empty.\
print('[yellow]WARN: ACTUAL_IPS is not set in the .env file or empty.\
This means that the real IP of the server could be exposed. If you\'re using something\
like Cloudflare or Repl.it, you can ignore this warning.[/yellow]')
if __name__ == '__main__':
enable_proxy()
ip_protection_check()

View file

@ -1,6 +1,8 @@
"""Module for transferring requests to ClosedAI API"""
import os
import json
import logging
import aiohttp
import aiohttp_socks
@ -8,11 +10,21 @@ import proxies
from dotenv import load_dotenv
import starlette
from starlette.responses import StreamingResponse
from starlette.background import BackgroundTask
load_dotenv()
# log to "api.log" file
logging.basicConfig(
filename='api.log',
level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(name)s %(message)s'
)
logging.info('API started')
EXCLUDED_HEADERS = [
'content-encoding',
'content-length',
@ -33,7 +45,16 @@ proxy_connector = aiohttp_socks.ProxyConnector(
async def transfer_streaming_response(incoming_request, target_endpoint: str='https://api.openai.com/v1'):
"""Transfer a streaming response from the incoming request to the target endpoint"""
incoming_json_payload = await incoming_request.json()
if incoming_request.headers.get('Authorization') != f'Bearer {os.getenv("DEMO_AUTH")}':
return starlette.responses.Response(
status_code=403,
content='Invalid API key'
)
try:
incoming_json_payload = await incoming_request.json()
except json.decoder.JSONDecodeError:
incoming_json_payload = {}
async def receive_target_stream():
connector = aiohttp_socks.ProxyConnector(
@ -46,12 +67,15 @@ async def transfer_streaming_response(incoming_request, target_endpoint: str='ht
)
async with aiohttp.ClientSession(
connector=connector,
timeout=aiohttp.ClientTimeout(total=120),
timeout=aiohttp.ClientTimeout(total=int(os.getenv('TRANSFER_TIMEOUT', '120'))),
raise_for_status=True
) as session:
target_url = f'{target_endpoint}{incoming_request.url.path}'.replace('/v1/v1', '/v1')
logging.info('TRANSFER %s -> %s', incoming_request.url.path, target_url)
async with session.request(
method=incoming_request.method,
url=f'{target_endpoint}/{incoming_request.url.path}',
url=target_url,
json=incoming_json_payload,
headers={
'Content-Type': 'application/json',
@ -60,6 +84,7 @@ async def transfer_streaming_response(incoming_request, target_endpoint: str='ht
) as response:
async for chunk in response.content.iter_any():
chunk = f'{chunk.decode("utf8")}\n\n'
logging.debug(chunk)
yield chunk
return StreamingResponse(

View file

@ -1,5 +1,5 @@
import sys
import os
port = sys.argv[1] if len(sys.argv) > 1 else 8000
port = sys.argv[1] if len(sys.argv) > 1 else 2333
os.system(f'cd api && uvicorn main:app --reload --host 0.0.0.0 --port {port}')

View file

@ -1,3 +0,0 @@
{
"MONGO_URI": ""
}

View file

@ -1,13 +1,13 @@
import setuptools
with open('README.md', 'r') as fh:
with open('README.md', 'r', encoding='utf8') as fh:
long_description = fh.read()
setuptools.setup(
name='nova-api',
version='0.0.1',
author='Luna OSS',
author_email='nsde@dmc.chat',
author='NovaOSS Contributors',
author_email='owner@nova-oss.com',
description='Nova API Server',
long_description=long_description,
long_description_content_type='text/markdown',

View file

@ -1,35 +1,42 @@
"""Tests the API."""
from typing import List
import os
import openai as closedai
import httpx
PORT = 8000
from typing import List
from dotenv import load_dotenv
load_dotenv()
MODEL = 'gpt-3.5-turbo'
MESSAGES = [
{
'role': 'system',
'content': 'Always answer with "3", no matter what the user asks for. No exceptions. Just answer with the number "3". Nothing else. Just "3". No punctuation.'
},
{
'role': 'user',
'content': 'Hello!',
'content': '1+1=',
},
]
ENDPOINT = f'http://localhost:{PORT}'
api_endpoint = 'http://localhost:2333'
def test_server():
"""Tests if the API is running."""
"""Tests if the API server is running."""
try:
return httpx.get(f'{ENDPOINT}').json()['status'] == 'ok'
return httpx.get(f'{api_endpoint}').json()['status'] == 'ok'
except httpx.ConnectError as exc:
raise ConnectionError(f'API is not running on port {PORT}.') from exc
raise ConnectionError(f'API is not running on port {api_endpoint}.') from exc
def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
"""Tests an API endpoint."""
"""Tests an API api_endpoint."""
headers = {
'Content-Type': 'application/json',
'Authorization': 'nv-API-TEST',
'Authorization': 'Bearer ' + os.getenv('DEMO_AUTH', 'nv-API-TEST'),
}
json_data = {
@ -38,20 +45,26 @@ def test_api(model: str=MODEL, messages: List[dict]=None) -> dict:
'stream': True,
}
response = httpx.post(f'{ENDPOINT}/chat/completions', headers=headers, json=json_data, timeout=20)
response = httpx.post(
url=f'{api_endpoint}/chat/completions',
headers=headers,
json=json_data,
timeout=20
)
response.raise_for_status()
return response
def test_library():
"""Tests if the endpoint is working with the "Closed"AI library."""
"""Tests if the api_endpoint is working with the Python library."""
closedai.api_base = ENDPOINT
closedai.api_key = 'nv-LIBRARY-TEST'
closedai.api_base = api_endpoint
closedai.api_key = os.getenv('DEMO_AUTH', 'nv-LIB-TEST')
completion = closedai.ChatCompletion.create(
model=MODEL,
messages=MESSAGES,
stream=True,
)
return completion.choices[0]
@ -59,9 +72,10 @@ def test_library():
def test_all():
"""Runs all tests."""
print(test_server())
# print(test_server())
print(test_api())
print(test_library())
# print(test_library())
if __name__ == '__main__':
api_endpoint = 'https://api.nova-oss.com'
test_all()