feat: add support for image-to-image generation in GeminiImageGenerationService

- Enhanced the `generate` method to accept base64 encoded input images and their MIME types.
- Implemented validation for input image and MIME type to ensure proper usage.
- Updated the content construction logic to handle both text-to-image and image-to-image generation scenarios.
This commit is contained in:
Nariman Jelveh
2025-09-15 15:08:06 -07:00
parent 47acb141da
commit f4fc24bce3

View File

@@ -65,11 +65,13 @@ class GeminiImageGenerationService extends BaseService {
* @param {Object} options - Generation options
* @param {Object} options.ratio - Image dimensions ratio object with w/h properties
* @param {string} [options.model='gemini-2.5-flash-image-preview'] - The model to use for generation
* @param {string} [options.input_image] - Base64 encoded input image for image-to-image generation
* @param {string} [options.input_image_mime_type] - MIME type of the input image
* @returns {Promise<string>} URL of the generated image
* @throws {Error} If prompt is not a string or ratio is invalid
*/
async generate(params) {
const { prompt, quality, test_mode, model, ratio } = params;
const { prompt, quality, test_mode, model, ratio, input_image, input_image_mime_type } = params;
if (test_mode) {
return new TypedValue({
@@ -77,10 +79,13 @@ class GeminiImageGenerationService extends BaseService {
content_type: 'image',
}, 'https://puter-sample-data.puter.site/image_example.png');
}
const url = await this.generate(prompt, {
quality,
ratio: ratio || this.constructor.RATIO_SQUARE,
model
model,
input_image,
input_image_mime_type
});
const image = new TypedValue({
@@ -98,6 +103,8 @@ class GeminiImageGenerationService extends BaseService {
async generate(prompt, {
ratio,
model,
input_image,
input_image_mime_type,
}) {
if (typeof prompt !== 'string') {
throw new Error('`prompt` must be a string');
@@ -107,6 +114,19 @@ class GeminiImageGenerationService extends BaseService {
throw new Error('`ratio` must be a valid ratio for model ' + model);
}
// Validate input image if provided
if (input_image && !input_image_mime_type) {
throw new Error('`input_image_mime_type` is required when `input_image` is provided');
}
if (input_image_mime_type && !input_image) {
throw new Error('`input_image` is required when `input_image_mime_type` is provided');
}
if (input_image_mime_type && !this._validate_image_mime_type(input_image_mime_type)) {
throw new Error('`input_image_mime_type` must be a valid image MIME type (image/png, image/jpeg, image/webp)');
}
// Somewhat sane defaults
model = model ?? 'gemini-2.5-flash-image-preview';
@@ -154,9 +174,27 @@ class GeminiImageGenerationService extends BaseService {
// We can charge immediately
await svc_cost.record_cost({ cost: exact_cost });
// Construct the prompt based on whether we have an input image
let contents;
if (input_image && input_image_mime_type) {
// Image-to-image generation
contents = [
{ text: `Generate a picture of dimensions ${parseInt(ratio.w)}x${parseInt(ratio.h)} with the prompt: ${prompt}` },
{
inlineData: {
mimeType: input_image_mime_type,
data: input_image,
},
},
];
} else {
// Text-to-image generation
contents = `Generate a picture of dimensions ${parseInt(ratio.w)}x${parseInt(ratio.h)} with the prompt: ${prompt}`;
}
const response = await this.genAI.models.generateContent({
model: "gemini-2.5-flash-image-preview",
contents: `Generate a picture of dimensions ${parseInt(ratio.w)}x${parseInt(ratio.h)} with the prompt: ${prompt}`,
contents: contents,
});
let url = undefined;
for (const part of response.candidates[0].content.parts) {
@@ -198,6 +236,17 @@ class GeminiImageGenerationService extends BaseService {
const validRatios = this._getValidRatios(model);
return validRatios.includes(ratio);
}
/**
* Validates if the provided MIME type is supported for input images
* @param {string} mimeType - The MIME type to validate
* @returns {boolean} True if the MIME type is supported
* @private
*/
_validate_image_mime_type(mimeType) {
const supportedTypes = ['image/png', 'image/jpeg', 'image/jpg', 'image/webp'];
return supportedTypes.includes(mimeType.toLowerCase());
}
}
module.exports = {