mirror of https://github.com/NovaOSS/nova-api.git (synced 2024-11-25 18:53:58 +01:00)

Commit 8b325d6b81: Fixed function calling
Parent: 98b5614345
@@ -1,30 +0,0 @@
-import json
-
-from helpers import chat
-
-async def process_chunks(
-    chunks,
-    is_chat: bool,
-    chat_id: int,
-    target_request: dict,
-    model: str=None,
-):
-    """This function processes the response chunks from the providers and yields them.
-    """
-    async for chunk in chunks:
-        chunk = chunk.decode("utf8").strip()
-        send = False
-
-        if is_chat and '{' in chunk:
-            data = json.loads(chunk.split('data: ')[1])
-            chunk = chunk.replace(data['id'], chat_id)
-            send = True
-
-            if target_request['module'] == 'twa' and data.get('text'):
-                chunk = await chat.create_chat_chunk(chat_id=chat_id, model=model, content=['text'])
-
-            if (not data['choices'][0]['delta']) or data['choices'][0]['delta'] == {'role': 'assistant'}:
-                send = False
-
-        if send and chunk:
-            yield chunk + '\n\n'
35  api/core.py
@@ -126,24 +126,23 @@ async def run_checks(incoming_request: fastapi.Request):
     if auth_error:
         return auth_error

-    try:
-        chat = await checks.client.test_chat()
-    except Exception as exc:
-        print(exc)
-        chat = None
-
-    try:
-        moderation = await checks.client.test_api_moderation()
-    except Exception:
-        moderation = None
-
-    try:
-        models = await checks.client.test_models()
-    except Exception:
-        models = None
-
-    return {
-        'chat/completions': chat,
-        'models': models,
-        'moderations': moderation,
-    }
+    results = {}
+
+    funcs = [
+        checks.client.test_chat_non_stream_gpt4,
+        checks.client.test_chat_stream_gpt3,
+        checks.client.test_function_calling,
+        checks.client.test_image_generation,
+        checks.client.test_speech_to_text,
+        checks.client.test_models
+    ]
+
+    for func in funcs:
+        try:
+            result = await func()
+        except Exception as exc:
+            results[func.__name__] = str(exc)
+        else:
+            results[func.__name__] = result
+
+    return results
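As a rough illustration (not part of the commit), the reworked run_checks() returns one entry per check, keyed by the function's __name__, holding either the measured latency in seconds or the stringified exception. The values below are hypothetical:

    results = {
        'test_chat_non_stream_gpt4': 3.41,    # hypothetical latency in seconds
        'test_chat_stream_gpt3': 1.87,
        'test_function_calling': 'The API did not return a correct response.',  # str(exc) on failure
        'test_image_generation': 7.02,
        'test_speech_to_text': 2.55,
        'test_models': 0.31,
    }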
@@ -124,7 +124,14 @@ async def handle(incoming_request: fastapi.Request):
     inp = payload.get('input', payload.get('prompt', ''))

     if isinstance(payload.get('messages'), list):
-        inp = '\n'.join([message['content'] for message in payload['messages']])
+        inp = ''
+
+        for message in payload.get('messages', []):
+            if message.get('role') == 'user':
+                inp += message.get('content', '') + '\n'
+
+        if 'functions' in payload:
+            inp += '\n'.join([function.get('description', '') for function in payload.get('functions', [])])

     if inp and len(inp) > 2 and not inp.isnumeric():
         policy_violation = await moderation.is_policy_violated(inp)
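To illustrate the fix itself: when a function-calling payload arrives, handle() now folds the user messages and the function descriptions into the text that is sent to moderation. A minimal sketch of what the new branch computes, with a hypothetical payload (not from the commit):

    payload = {
        'messages': [{'role': 'user', 'content': 'What is the weather in Boston?'}],
        'functions': [{'name': 'get_current_weather',
                       'description': 'Get the current weather in a given location'}],
    }

    inp = ''
    for message in payload.get('messages', []):
        if message.get('role') == 'user':
            inp += message.get('content', '') + '\n'

    if 'functions' in payload:
        inp += '\n'.join(function.get('description', '') for function in payload.get('functions', []))

    # inp == 'What is the weather in Boston?\nGet the current weather in a given location'
    # and this combined string is what moderation.is_policy_violated(inp) receives.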
@@ -148,7 +155,7 @@ async def handle(incoming_request: fastapi.Request):
             path=path,
             payload=payload,
             credits_cost=cost,
-            input_tokens=-1,
+            input_tokens=0,
             incoming_request=incoming_request,
         ),
         media_type=media_type
17  api/main.py
@@ -2,11 +2,9 @@

 import fastapi
 import pydantic
-import functools

 from rich import print
 from dotenv import load_dotenv
-from json import JSONDecodeError
 from bson.objectid import ObjectId
 from slowapi.errors import RateLimitExceeded
 from slowapi.middleware import SlowAPIMiddleware
@@ -17,7 +15,6 @@ from helpers import network

 import core
 import handler
-import moderation

 load_dotenv()

@@ -66,17 +63,5 @@ async def root():

 @app.route('/v1/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE', 'PATCH'])
 async def v1_handler(request: fastapi.Request):
-    res = await handler.handle(request)
+    res = await handler.handle(incoming_request=request)
     return res
-
-@functools.lru_cache()
-@app.post('/moderate')
-async def moderate(request: fastapi.Request):
-    try:
-        prompt = await request.json()
-        prompt = prompt['text']
-    except (KeyError, JSONDecodeError):
-        return fastapi.Response(status_code=400)
-
-    result = await moderation.is_policy_violated__own_model(prompt)
-    return result or ''
@@ -11,7 +11,6 @@ import starlette
 from rich import print
 from dotenv import load_dotenv

-import chunks
 import proxies
 import provider_auth
 import after_request
@@ -21,24 +20,6 @@ from helpers import network, chat, errors

 load_dotenv()

-## Loads config which contains rate limits
-with open('config/config.yml', encoding='utf8') as f:
-    config = yaml.safe_load(f)
-
-## Where all rate limit requested data will be stored.
-# Rate limit data is **not persistent** (It will be deleted on server stop/restart).
-user_last_request_time = {}
-
-DEMO_PAYLOAD = {
-    'model': 'gpt-3.5-turbo',
-    'messages': [
-        {
-            'role': 'user',
-            'content': '1+1='
-        }
-    ]
-}
-
 async def respond(
     path: str='/v1/chat/completions',
     user: dict=None,
@@ -52,27 +33,22 @@ async def respond(
     """

     is_chat = False
-    is_stream = payload.get('stream', False)

     model = None
+    is_stream = False

     if 'chat/completions' in path:
         is_chat = True
         model = payload['model']

-    if is_chat and is_stream:
-        chat_id = await chat.create_chat_id()
-        yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=chat.CompletionStart)
-        yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=None)
-
     json_response = {}

     headers = {
         'Content-Type': 'application/json',
-        'User-Agent': 'null'
+        'User-Agent': 'axios/0.21.1',
     }

-    for _ in range(5):
+    for _ in range(10):
         # Load balancing: randomly selecting a suitable provider
         # If the request is a chat completion, then we need to load balance between chat providers
         # If the request is an organic request, then we need to load balance between organic providers
@@ -115,10 +91,11 @@ async def respond(
                 cookies=target_request.get('cookies'),
                 ssl=False,
                 timeout=aiohttp.ClientTimeout(
-                    connect=0.5,
+                    connect=0.3,
                     total=float(os.getenv('TRANSFER_TIMEOUT', '500'))
                 ),
             ) as response:
+                is_stream = response.content_type == 'text/event-stream'

                 if response.status == 429:
                     continue
@@ -144,35 +121,27 @@ async def respond(
                    if 'Too Many Requests' in str(exc):
                        continue

-                async for chunk in chunks.process_chunks(
-                    chunks=response.content.iter_any(),
-                    is_chat=is_chat,
-                    chat_id=chat_id,
-                    model=model,
-                    target_request=target_request
-                ):
-                    yield chunk
+                async for chunk in response.content.iter_any():
+                    chunk = chunk.decode('utf8').strip()
+                    yield chunk + '\n\n'

                 break

         except Exception as exc:
-            # print(f'[!] {type(exc)} - {exc}')
             continue

         if (not json_response) and is_chat:
             print('[!] chat response is empty')
             continue
     else:
-        yield await errors.yield_error(500, 'Sorry, the API is not responding.', 'Please try again later.')
+        yield await errors.yield_error(500, 'Sorry, the provider is not responding. We\'re possibly getting rate-limited.', 'Please try again later.')
         return

-    if is_chat and is_stream:
-        yield await chat.create_chat_chunk(chat_id=chat_id, model=model, content=chat.CompletionStop)
-        yield 'data: [DONE]\n\n'
-
     if (not is_stream) and json_response:
         yield json.dumps(json_response)

+    print(f'[+] {path} -> {model or ""}')
+
     await after_request.after_request(
         incoming_request=incoming_request,
         target_request=target_request,
@@ -183,5 +152,3 @@ async def respond(
         is_chat=is_chat,
         model=model,
     )
-
-    print(f'[+] {path} -> {model or ""}')
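A short note on the streaming change above: with chunks.process_chunks() deleted, respond() now forwards the provider's streamed bytes almost verbatim instead of rewriting the chat id and wrapping the stream in synthetic start/stop chunks. A sketch with a hypothetical provider chunk:

    raw = b'data: {"id": "chatcmpl-abc123", "choices": [{"delta": {"content": "Hi"}}]}'  # hypothetical SSE line
    forwarded = raw.decode('utf8').strip() + '\n\n'   # exactly what the new loop yields, unmodified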
@@ -2,6 +2,7 @@

 import os
 import time
+import json
 import httpx
 import openai
 import asyncio
@@ -10,6 +11,7 @@ import traceback
 from rich import print
 from typing import List
 from dotenv import load_dotenv
+from pydantic import BaseModel

 load_dotenv()

@@ -43,12 +45,12 @@ async def test_server():
     else:
         return time.perf_counter() - request_start

-async def test_chat_non_stream(model: str=MODEL, messages: List[dict]=None) -> dict:
-    """Tests an API api_endpoint."""
+async def test_chat_non_stream_gpt4() -> float:
+    """Tests non-streamed chat completions with the GPT-4 model."""

     json_data = {
-        'model': model,
-        'messages': messages or MESSAGES,
+        'model': 'gpt-4',
+        'messages': MESSAGES,
         'stream': False
     }

@@ -66,7 +68,30 @@ async def test_chat_non_stream(model: str=MODEL, messages: List[dict]=None) -> dict:
     assert '2' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.'
     return time.perf_counter() - request_start

-async def test_sdxl():
+async def test_chat_stream_gpt3() -> float:
+    """Tests the text stream endpoint with the GPT-3.5-Turbo model."""
+
+    json_data = {
+        'model': 'gpt-3.5-turbo',
+        'messages': MESSAGES,
+        'stream': True,
+    }
+
+    request_start = time.perf_counter()
+
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            url=f'{api_endpoint}/chat/completions',
+            headers=HEADERS,
+            json=json_data,
+            timeout=10,
+        )
+        response.raise_for_status()
+
+    assert '2' in response.json()['choices'][0]['message']['content'], 'The API did not return a correct response.'
+    return time.perf_counter() - request_start
+
+async def test_image_generation() -> float:
     """Tests the image generation endpoint with the SDXL model."""

     json_data = {
@@ -89,6 +114,48 @@ async def test_sdxl():
     assert '://' in response.json()['data'][0]['url']
     return time.perf_counter() - request_start

+class StepByStepAIResponse(BaseModel):
+    """Demo response structure for the function calling test."""
+    title: str
+    steps: List[str]
+
+async def test_function_calling():
+    """Tests function calling functionality with newer GPT models."""
+
+    json_data = {
+        'stream': False,
+        'model': 'gpt-3.5-turbo-0613',
+        'messages': [
+            {"role": "user", "content": "Explain how to assemble a PC"}
+        ],
+        'functions': [
+            {
+                'name': 'get_answer_for_user_query',
+                'description': 'Get user answer in series of steps',
+                'parameters': StepByStepAIResponse.schema()
+            }
+        ],
+        'function_call': {'name': 'get_answer_for_user_query'}
+    }
+
+    request_start = time.perf_counter()
+
+    async with httpx.AsyncClient() as client:
+        response = await client.post(
+            url=f'{api_endpoint}/chat/completions',
+            headers=HEADERS,
+            json=json_data,
+            timeout=10,
+        )
+        response.raise_for_status()
+
+    res = response.json()
+    output = json.loads(res['choices'][0]['message']['function_call']['arguments'])
+    print(output)
+
+    assert output.get('title') and output.get('steps'), 'The API did not return a correct response.'
+    return time.perf_counter() - request_start
+
 async def test_models():
     """Tests the models endpoint."""

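A side note on the new test: StepByStepAIResponse.schema() is what gets sent as the function's 'parameters'. With pydantic v1 it produces roughly the following JSON Schema (approximate, shown for orientation only):

    {
        'title': 'StepByStepAIResponse',
        'description': 'Demo response structure for the function calling test.',
        'type': 'object',
        'properties': {
            'title': {'title': 'Title', 'type': 'string'},
            'steps': {'title': 'Steps', 'type': 'array', 'items': {'type': 'string'}},
        },
        'required': ['title', 'steps'],
    }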
@@ -122,14 +189,17 @@ async def demo():
     else:
         raise ConnectionError('API Server is not running.')

+    # print('[lightblue]Checking if function calling works...')
+    # print(await test_function_calling())
+
     print('Checking non-streamed chat completions...')
-    print(await test_chat_non_stream())
+    print(await test_chat_non_stream_gpt4())

-    # print('[lightblue]Checking if SDXL image generation works...')
-    # print(await test_sdxl())
+    print('Checking streamed chat completions...')
+    print(await test_chat_stream_gpt3())

-    # print('[lightblue]Checking if the moderation endpoint works...')
-    # print(await test_api_moderation())
+    print('[lightblue]Checking if image generation works...')
+    print(await test_image_generation())

     print('Checking the models endpoint...')
     print(await test_models())
82  playground/functioncalling.py (new file)
@@ -0,0 +1,82 @@
+import os
+import json
+import openai
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+openai.api_base = 'http://localhost:2332/v1'
+openai.api_key = os.environ['NOVA_KEY']
+
+# Example dummy function hard coded to return the same weather
+# In production, this could be your backend API or an external API
+def get_current_weather(location, unit='fahrenheit'):
+    """Get the current weather in a given location"""
+    weather_info = {
+        'location': location,
+        'temperature': '72',
+        'unit': unit,
+        'forecast': ['sunny', 'windy'],
+    }
+    return json.dumps(weather_info)
+
+def run_conversation():
+    # Step 1: send the conversation and available functions to GPT
+    messages = [{'role': 'user', 'content': 'What\'s the weather like in Boston?'}]
+    functions = [
+        {
+            'name': 'get_current_weather',
+            'description': 'Get the current weather in a given location',
+            'parameters': {
+                'type': 'object',
+                'properties': {
+                    'location': {
+                        'type': 'string',
+                        'description': 'The city and state, e.g. San Francisco, CA',
+                    },
+                    'unit': {'type': 'string', 'enum': ['celsius', 'fahrenheit']},
+                },
+                'required': ['location'],
+            },
+        }
+    ]
+    response = openai.ChatCompletion.create(
+        model='gpt-3.5-turbo-0613',
+        messages=messages,
+        functions=functions,
+        function_call='auto',  # auto is default, but we'll be explicit
+    )
+    response_message = response['choices'][0]['message']
+
+    # Step 2: check if GPT wanted to call a function
+    if response_message.get('function_call'):
+        # Step 3: call the function
+        # Note: the JSON response may not always be valid; be sure to handle errors
+        available_functions = {
+            'get_current_weather': get_current_weather,
+        }  # only one function in this example, but you can have multiple
+        function_name = response_message['function_call']['name']
+        fuction_to_call = available_functions[function_name]
+        function_args = json.loads(response_message['function_call']['arguments'])
+        function_response = fuction_to_call(
+            location=function_args.get('location'),
+            unit=function_args.get('unit'),
+        )
+
+        # Step 4: send the info on the function call and function response to GPT
+        messages.append(response_message)  # extend conversation with assistant's reply
+        messages.append(
+            {
+                'role': 'function',
+                'name': function_name,
+                'content': function_response,
+            }
+        )  # extend conversation with function response
+        second_response = openai.ChatCompletion.create(
+            model='gpt-3.5-turbo-0613',
+            messages=messages,
+        )  # get a new response from GPT where it can see the function response
+        return second_response
+
+print(run_conversation())
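Assuming the API is running locally on port 2332 and NOVA_KEY is set in the environment (or in a .env file picked up by load_dotenv), the playground script can be run directly with python playground/functioncalling.py; it should print the assistant's follow-up message built from the dummy get_current_weather() result.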