Skip to content

Commit

Permalink
Merge pull request #118 from dongri/refactor-function-types
Browse files Browse the repository at this point in the history
refactoring function
  • Loading branch information
dongri authored Oct 15, 2024
2 parents 34aae91 + 5867baa commit e46412c
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 58 deletions.
11 changes: 6 additions & 5 deletions examples/function_call.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use openai_api_rs::v1::common::GPT4_O;
use openai_api_rs::v1::types;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{env, vec};
Expand All @@ -21,8 +22,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut properties = HashMap::new();
properties.insert(
"coin".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::String),
Box::new(types::JSONSchemaDefine {
schema_type: Some(types::JSONSchemaType::String),
description: Some("The cryptocurrency to get the price of".to_string()),
..Default::default()
}),
Expand All @@ -40,11 +41,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.tools(vec![chat_completion::Tool {
r#type: chat_completion::ToolType::Function,
function: chat_completion::Function {
function: types::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
parameters: types::FunctionParameters {
schema_type: types::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
Expand Down
11 changes: 6 additions & 5 deletions examples/function_call_role.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use openai_api_rs::v1::api::OpenAIClient;
use openai_api_rs::v1::chat_completion::{self, ChatCompletionRequest};
use openai_api_rs::v1::common::GPT4_O;
use openai_api_rs::v1::types;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::{env, vec};
Expand All @@ -21,8 +22,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let mut properties = HashMap::new();
properties.insert(
"coin".to_string(),
Box::new(chat_completion::JSONSchemaDefine {
schema_type: Some(chat_completion::JSONSchemaType::String),
Box::new(types::JSONSchemaDefine {
schema_type: Some(types::JSONSchemaType::String),
description: Some("The cryptocurrency to get the price of".to_string()),
..Default::default()
}),
Expand All @@ -40,11 +41,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
)
.tools(vec![chat_completion::Tool {
r#type: chat_completion::ToolType::Function,
function: chat_completion::Function {
function: types::Function {
name: String::from("get_coin_price"),
description: Some(String::from("Get the price of a cryptocurrency")),
parameters: chat_completion::FunctionParameters {
schema_type: chat_completion::JSONSchemaType::Object,
parameters: types::FunctionParameters {
schema_type: types::JSONSchemaType::Object,
properties: Some(properties),
required: Some(vec![String::from("coin")]),
},
Expand Down
43 changes: 42 additions & 1 deletion src/v1/assistant.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

use super::types;
use crate::impl_builder_methods;

#[derive(Debug, Serialize, Clone)]
Expand Down Expand Up @@ -56,13 +57,53 @@ pub struct AssistantObject {
pub model: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
pub tools: Vec<HashMap<String, String>>,
pub tools: Vec<Tools>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_resources: Option<ToolResource>,
pub metadata: Option<HashMap<String, String>>,
pub headers: Option<HashMap<String, String>>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
#[serde(rename_all = "snake_case")]
pub enum Tools {
CodeInterpreter,
FileSearch(ToolsFileSearch),
Function(ToolsFunction),
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolsFileSearch {
#[serde(skip_serializing_if = "Option::is_none")]
pub file_search: Option<ToolsFileSearchObject>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolsFunction {
pub function: types::Function,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolsFileSearchObject {
pub max_num_results: Option<u8>,
pub ranking_options: Option<FileSearchRankingOptions>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FileSearchRankingOptions {
pub ranker: Option<FileSearchRanker>,
pub score_threshold: Option<f32>,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum FileSearchRanker {
#[serde(rename = "auto")]
Auto,
#[serde(rename = "default_2024_08_21")]
Default2024_08_21,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolResource {
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down
51 changes: 4 additions & 47 deletions src/v1/chat_completion.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::{common, types};
use crate::impl_builder_methods;
use crate::v1::common;

use serde::de::{self, MapAccess, SeqAccess, Visitor};
use serde::ser::SerializeMap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
Expand Down Expand Up @@ -185,6 +186,7 @@ impl<'de> Deserialize<'de> for Content {
deserializer.deserialize_any(ContentVisitor)
}
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum ContentType {
Expand Down Expand Up @@ -251,51 +253,6 @@ pub struct ChatCompletionResponse {
pub headers: Option<HashMap<String, String>>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: FunctionParameters,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum JSONSchemaType {
Object,
Number,
String,
Array,
Null,
Boolean,
}

#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
pub struct JSONSchemaDefine {
#[serde(rename = "type")]
pub schema_type: Option<JSONSchemaType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub items: Option<Box<JSONSchemaDefine>>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct FunctionParameters {
#[serde(rename = "type")]
pub schema_type: JSONSchemaType,
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
}

#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub enum FinishReason {
Expand Down Expand Up @@ -352,7 +309,7 @@ where
#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Tool {
pub r#type: ToolType,
pub function: Function,
pub function: types::Function,
}

#[derive(Debug, Deserialize, Serialize, Copy, Clone, PartialEq, Eq)]
Expand Down
1 change: 1 addition & 0 deletions src/v1/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
pub mod common;
pub mod error;
pub mod types;

pub mod audio;
pub mod batch;
Expand Down
47 changes: 47 additions & 0 deletions src/v1/types.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct Function {
pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
pub parameters: FunctionParameters,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
pub struct FunctionParameters {
#[serde(rename = "type")]
pub schema_type: JSONSchemaType,
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
}

#[derive(Debug, Deserialize, Serialize, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum JSONSchemaType {
Object,
Number,
String,
Array,
Null,
Boolean,
}

#[derive(Debug, Deserialize, Serialize, Clone, Default, PartialEq, Eq)]
pub struct JSONSchemaDefine {
#[serde(rename = "type")]
pub schema_type: Option<JSONSchemaType>,
#[serde(skip_serializing_if = "Option::is_none")]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub properties: Option<HashMap<String, Box<JSONSchemaDefine>>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub items: Option<Box<JSONSchemaDefine>>,
}

0 comments on commit e46412c

Please sign in to comment.