diff --git a/api/core.py b/api/core.py index b8011d3..e343e04 100644 --- a/api/core.py +++ b/api/core.py @@ -52,8 +52,6 @@ async def get_users(discord_id: int, incoming_request: fastapi.Request): if not user: return await errors.error(404, 'Discord user not found in the API database.', 'Check the `discord_id` parameter.') - print(type(user)) - print(user) return user async def new_user_webhook(user: dict) -> None: diff --git a/api/transfer.py b/api/handler.py similarity index 98% rename from api/transfer.py rename to api/handler.py index 4d35a0c..ff725dc 100644 --- a/api/transfer.py +++ b/api/handler.py @@ -21,7 +21,7 @@ models_list = json.load(open('models.json', encoding='utf8')) with open('config/config.yml', encoding='utf8') as f: config = yaml.safe_load(f) -async def handle(incoming_request): +async def handle(incoming_request: fastapi.Request): """ ### Transfer a streaming response Takes the request from the incoming request to the target endpoint. diff --git a/api/helpers/network.py b/api/helpers/network.py index 997f497..50859db 100644 --- a/api/helpers/network.py +++ b/api/helpers/network.py @@ -1,4 +1,3 @@ -import base64 import asyncio async def get_ip(request) -> str: @@ -17,3 +16,20 @@ async def get_ip(request) -> str: detected_ip = next((i for i in possible_ips if i), None) return detected_ip + +def get_ip_sync(request) -> str: + """Get the IP address of the incoming request.""" + + xff = None + if request.headers.get('x-forwarded-for'): + xff, *_ = request.headers['x-forwarded-for'].split(', ') + + possible_ips = [ + xff, + request.headers.get('cf-connecting-ip'), + request.client.host + ] + + detected_ip = next((i for i in possible_ips if i), None) + + return detected_ip diff --git a/api/main.py b/api/main.py index dbce369..fc76e8a 100644 --- a/api/main.py +++ b/api/main.py @@ -6,10 +6,15 @@ import pydantic from rich import print from dotenv import load_dotenv from bson.objectid import ObjectId +from slowapi.errors import RateLimitExceeded +from slowapi.middleware import SlowAPIMiddleware from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter, _rate_limit_exceeded_handler + +from helpers import network import core -import transfer +import handler load_dotenv() @@ -25,6 +30,17 @@ app.add_middleware( app.include_router(core.router) +limiter = Limiter( + swallow_errors=True, + key_func=network.get_ip_sync, default_limits=[ + '2/second', + '20/minute', + '300/hour' +]) +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) +app.add_middleware(SlowAPIMiddleware) + @app.on_event('startup') async def startup_event(): """Runs when the API starts up.""" @@ -45,4 +61,6 @@ async def root(): 'ping': 'pong' } -app.add_route('/v1/{path:path}', transfer.handle, ['GET', 'POST', 'PUT', 'DELETE', 'PATCH']) +@app.route('/v1/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH']) +async def v1_handler(request: fastapi.Request): + return await handler.handle(request) diff --git a/api/streaming.py b/api/streaming.py index 5075aa3..d6a32a4 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -70,7 +70,7 @@ async def stream( yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=chat.CompletionStart) yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=None) - json_response = {'error': 'No JSON response could be received'} + json_response = {} headers = { 'Content-Type': 'application/json', @@ -110,22 +110,6 @@ async def stream( # We haven't done any requests as of right now, everything until now was just preparation # Here, we process the request async with aiohttp.ClientSession(connector=proxies.get_proxy().connector) as session: - # try: - # async with session.get( - # url='https://checkip.amazonaws.com', - # timeout=aiohttp.ClientTimeout( - # connect=0.4, - # total=0.7 - # ) - # ) as response: - # for actual_ip in os.getenv('ACTUAL_IPS', '').split(' '): - # if actual_ip in await response.text(): - # raise ValueError(f'Proxy {response.text()} is transparent!') - - # except Exception as exc: - # print(f'[!] proxy {proxies.get_proxy()} error - ({type(exc)} {exc})') - # continue - try: async with session.request( method=target_request.get('method', 'POST'), @@ -172,15 +156,19 @@ async def stream( break except ProxyError as exc: - print('[!] aiohttp came up with a dumb excuse to not work again ("pRoXy ErRor")') + print('[!] aiohttp ProxyError') continue except ConnectionResetError as exc: - print('[!] aiohttp came up with a dumb excuse to not work again ("cOnNeCtIoN rEsEt")') + print('[!] aiohttp ConnectionResetError') continue except aiohttp.client_exceptions.ClientConnectionError: - print('[!] aiohttp came up with a dumb excuse to not work again ("cOnNeCtIoN cLosEd")') + print('[!] aiohttp ClientConnectionError') + continue + + if not json_response and is_chat and is_stream: + print('[!] chat response is empty') continue if is_chat and is_stream: diff --git a/setup.py b/setup.py index 8ae4abc..2a208e4 100644 --- a/setup.py +++ b/setup.py @@ -1,20 +1,23 @@ -import setuptools +from fastapi import FastAPI +from fastapi.responses import PlainTextResponse +from fastapi.requests import Request +from fastapi.responses import Response +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.util import get_remote_address +from slowapi.errors import RateLimitExceeded -with open('README.md', 'r', encoding='utf8') as fh: - long_description = fh.read() +limiter = Limiter(key_func=lambda: "test", default_limits=["5/minute"]) +app = FastAPI() +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) -setuptools.setup( - name='nova-api', - version='0.0.1', - author='NovaOSS Contributors', - author_email='owner@nova-oss.com', - description='Nova API Server', - long_description=long_description, - long_description_content_type='text/markdown', - packages=setuptools.find_packages(), - classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', - ] -) +# Note: the route decorator must be above the limit decorator, not below it +@app.get("/home") +@limiter.limit("5/minute") +async def homepage(request: Request): + return PlainTextResponse("test") + +@app.get("/mars") +@limiter.limit("5/minute") +async def homepage(request: Request, response: Response): + return {"key": "value"}