Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 211 additions & 71 deletions src-tauri/crates/providers/src/openai_images.rs

Large diffs are not rendered by default.

63 changes: 62 additions & 1 deletion src-tauri/src/commands/drawing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use aqbot_core::repo::drawing::{
use aqbot_core::repo::stored_file::StoredFile;
use aqbot_core::types::{ProviderConfig, ProviderProxyConfig, ProviderType};
use aqbot_providers::openai_images::{
ImageEditRequest, ImageEditTransferMode, ImageGenerateRequest, ImageUpload, OpenAIImagesClient,
ImageEditImageFormat, ImageEditRequest, ImageEditTransferMode, ImageGenerateRequest, ImageUpload, OpenAIImagesClient,
};
use aqbot_providers::{resolve_base_url_for_type, ProviderRequestContext};
use base64::Engine;
Expand Down Expand Up @@ -38,7 +38,15 @@ pub struct DrawingGenerateInput {
#[serde(default)]
pub reference_image_mode: DrawingReferenceImageMode,
#[serde(default)]
pub reference_image_format: DrawingReferenceImageFormat,
#[serde(default)]
pub reference_image_param_name: String,
#[serde(default)]
pub reference_file_ids: Vec<String>,
#[serde(default)]
pub generation_api_path: String,
#[serde(default)]
pub edit_api_path: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -56,7 +64,15 @@ pub struct DrawingEditInput {
#[serde(default)]
pub reference_image_mode: DrawingReferenceImageMode,
#[serde(default)]
pub reference_image_format: DrawingReferenceImageFormat,
#[serde(default)]
pub reference_image_param_name: String,
#[serde(default)]
pub reference_file_ids: Vec<String>,
#[serde(default)]
pub generation_api_path: String,
#[serde(default)]
pub edit_api_path: String,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand All @@ -75,7 +91,15 @@ pub struct DrawingMaskEditInput {
#[serde(default)]
pub reference_image_mode: DrawingReferenceImageMode,
#[serde(default)]
pub reference_image_format: DrawingReferenceImageFormat,
#[serde(default)]
pub reference_image_param_name: String,
#[serde(default)]
pub reference_file_ids: Vec<String>,
#[serde(default)]
pub generation_api_path: String,
#[serde(default)]
pub edit_api_path: String,
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
Expand All @@ -100,6 +124,28 @@ impl From<DrawingReferenceImageMode> for ImageEditTransferMode {
}
}

#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum DrawingReferenceImageFormat {
Object,
String,
}

impl Default for DrawingReferenceImageFormat {
fn default() -> Self {
Self::Object
}
}

impl From<DrawingReferenceImageFormat> for ImageEditImageFormat {
fn from(value: DrawingReferenceImageFormat) -> Self {
match value {
DrawingReferenceImageFormat::Object => ImageEditImageFormat::Object,
DrawingReferenceImageFormat::String => ImageEditImageFormat::String,
}
}
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DrawingUploadInput {
pub data: String,
Expand Down Expand Up @@ -189,6 +235,9 @@ pub async fn generate_drawing_images(
)
.await?;

let generation_path = if input.generation_api_path.is_empty() { None } else { Some(input.generation_api_path.as_str()) };
let edit_path = if input.edit_api_path.is_empty() { None } else { Some(input.edit_api_path.as_str()) };

let result = if input.reference_file_ids.is_empty() {
OpenAIImagesClient::new()
.generate(
Expand All @@ -203,6 +252,7 @@ pub async fn generate_drawing_images(
background: input.background.clone(),
output_compression: input.output_compression,
},
generation_path,
)
.await
} else {
Expand All @@ -220,9 +270,12 @@ pub async fn generate_drawing_images(
background: input.background.clone(),
output_compression: input.output_compression,
transfer_mode: input.reference_image_mode.into(),
image_format: input.reference_image_format.into(),
image_param_name: input.reference_image_param_name.clone(),
images: uploads,
mask: None,
},
edit_path,
)
.await
};
Expand Down Expand Up @@ -263,6 +316,7 @@ pub async fn edit_drawing_image(
None,
)
.await?;
let edit_path = if input.edit_api_path.is_empty() { None } else { Some(input.edit_api_path.as_str()) };
let mut uploads = vec![load_drawing_image_upload(&state, &source).await?];
uploads.extend(load_reference_uploads(&state, &input.reference_file_ids).await?);
let result = OpenAIImagesClient::new()
Expand All @@ -278,9 +332,12 @@ pub async fn edit_drawing_image(
background: input.background.clone(),
output_compression: input.output_compression,
transfer_mode: input.reference_image_mode.into(),
image_format: input.reference_image_format.into(),
image_param_name: input.reference_image_param_name.clone(),
images: uploads,
mask: None,
},
edit_path,
)
.await;

Expand Down Expand Up @@ -330,6 +387,7 @@ pub async fn edit_drawing_image_with_mask(
Some(input.mask_file_id.clone()),
)
.await?;
let edit_path = if input.edit_api_path.is_empty() { None } else { Some(input.edit_api_path.as_str()) };
let mut uploads = vec![load_drawing_image_upload(&state, &source).await?];
uploads.extend(load_reference_uploads(&state, &input.reference_file_ids).await?);
let mask = Some(load_stored_file_upload(&state, &mask_file).await?);
Expand All @@ -346,9 +404,12 @@ pub async fn edit_drawing_image_with_mask(
background: input.background.clone(),
output_compression: input.output_compression,
transfer_mode: input.reference_image_mode.into(),
image_format: input.reference_image_format.into(),
image_param_name: input.reference_image_param_name.clone(),
images: uploads,
mask,
},
edit_path,
)
.await;

Expand Down
4 changes: 4 additions & 0 deletions src/components/drawing/DrawingComposer.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,12 @@ export function DrawingComposer({ settings, prompt, onPromptChange, onHeightChan
background: settings.background,
output_compression: settings.outputCompression,
reference_image_mode: settings.referenceImageMode,
reference_image_format: settings.referenceImageFormat,
reference_image_param_name: settings.referenceImageParamName,
n: settings.n,
reference_file_ids: references.map((item) => item.id),
generation_api_path: settings.generationApiPath,
edit_api_path: settings.editApiPath,
};
onPromptChange('');
if (editSourceImage && editMaskFileId) {
Expand Down
Loading