mirror of
https://github.com/NovaOSS/nova-python.git
synced 2024-11-25 17:23:59 +01:00
Made seconds_until_timeout an int and added a custom type for Chat Completions
This commit is contained in:
parent
59b7c6073a
commit
2d113df795
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "nova-python"
|
||||
version = "0.1.2"
|
||||
version = "0.1.3"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
@ -12,3 +12,4 @@ crate-type = ["cdylib"]
|
|||
pyo3 = "0.19.0"
|
||||
reqwest = "0.11.18"
|
||||
tokio = { version = "1.29.1", features = ["rt-multi-thread", "time"] }
|
||||
serde_json = "1.0.104"
|
61
src/lib.rs
61
src/lib.rs
|
@ -54,7 +54,7 @@ impl NovaClient {
|
|||
}
|
||||
|
||||
// Used to make a request to the Nova API
|
||||
fn make_request(&self, endpoint: Endpoints, model: Models, data: Vec<Py<PyDict>>, seconds_until_timeout: Option<String>) -> PyResult<String> {
|
||||
fn make_request(&self, endpoint: Endpoints, model: Models, data: Vec<Py<PyDict>>, seconds_until_timeout: Option<usize>) -> PyResult<PyObject> {
|
||||
if !model_is_compatible(&endpoint, &model) {
|
||||
return Err(NovaClient::get_endpoint_not_compatible_error());
|
||||
}
|
||||
|
@ -64,13 +64,13 @@ impl NovaClient {
|
|||
|
||||
let rt = tokio::runtime::Runtime::new().unwrap();
|
||||
let seconds_until_timeout = match seconds_until_timeout {
|
||||
Some(seconds_until_timeout) => seconds_until_timeout.parse::<u64>().unwrap(),
|
||||
Some(seconds_until_timeout) => seconds_until_timeout,
|
||||
None => 30
|
||||
};
|
||||
|
||||
let response: Result<String, reqwest::Error> = rt.block_on(async {
|
||||
let unmatched_response: Result<String, reqwest::Error> = rt.block_on(async {
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(time::Duration::from_secs(seconds_until_timeout))
|
||||
.timeout(time::Duration::from_secs(seconds_until_timeout as u64))
|
||||
.user_agent("Mozilla/5.0")
|
||||
.build()
|
||||
.unwrap();
|
||||
|
@ -86,9 +86,25 @@ impl NovaClient {
|
|||
Ok(text)
|
||||
});
|
||||
|
||||
match response {
|
||||
Ok(response) => Ok(response),
|
||||
Err(response) => Err(pyo3::exceptions::PyRuntimeError::new_err(response.to_string()))
|
||||
let reponse = match unmatched_response {
|
||||
Ok(unmatched_response) => Ok(unmatched_response),
|
||||
Err(unmatched_response) => Err(pyo3::exceptions::PyRuntimeError::new_err(unmatched_response.to_string()))
|
||||
}.unwrap();
|
||||
|
||||
if endpoint == Endpoints::ChatCompletion {
|
||||
let final_reponse = Python::with_gil(|py| {
|
||||
let reponse = ChatResponse::new(reponse);
|
||||
reponse.into_py(py)
|
||||
});
|
||||
|
||||
return Ok(final_reponse);
|
||||
|
||||
} else {
|
||||
let final_reponse = Python::with_gil(|py| {
|
||||
reponse.into_py(py)
|
||||
});
|
||||
|
||||
return Ok(final_reponse);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,9 +197,36 @@ impl NovaClient {
|
|||
fn get_invalid_model_error() -> PyErr {
|
||||
pyo3::exceptions::PyValueError::new_err("Invalid model")
|
||||
}
|
||||
}
|
||||
|
||||
fn get_request_failed_error() -> PyErr {
|
||||
pyo3::exceptions::PyRuntimeError::new_err("Request failed for unknown reasons.")
|
||||
#[pyclass(module = "nova_python", frozen)]
|
||||
struct ChatResponse {
|
||||
json: String,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl ChatResponse {
|
||||
#[new]
|
||||
fn new(json: String) -> Self {
|
||||
ChatResponse {
|
||||
json
|
||||
}
|
||||
}
|
||||
|
||||
fn get_message_content(&self) -> PyResult<String> {
|
||||
let json = &self.json;
|
||||
let json: serde_json::Value = serde_json::from_str(json).unwrap();
|
||||
|
||||
let content = json["choices"][0]["message"]["content"].as_str().unwrap();
|
||||
Ok(content.trim().to_string())
|
||||
}
|
||||
|
||||
fn __str__(&self) -> PyResult<String> {
|
||||
Ok(self.json.clone())
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> PyResult<String> {
|
||||
Ok(format!("ChatResponse(json={})", self.json))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue