diff --git a/prodia.ts b/prodia.ts index 12c2253..c08a27c 100644 --- a/prodia.ts +++ b/prodia.ts @@ -1,14 +1,9 @@ /* Job Responses */ -type ProdiaJobBase = { job: string }; - -export type ProdiaJobQueued = ProdiaJobBase & { status: "queued" }; -export type ProdiaJobGenerating = ProdiaJobBase & { status: "generating" }; -export type ProdiaJobFailed = ProdiaJobBase & { status: "failed" }; -export type ProdiaJobSucceeded = ProdiaJobBase & { - status: "succeeded"; - imageUrl: string; -}; +export type ProdiaJobQueued = { imageUrl: undefined; status: "queued" }; +export type ProdiaJobGenerating = { imageUrl: undefined; status: "generating" }; +export type ProdiaJobFailed = { imageUrl: undefined; status: "failed" }; +export type ProdiaJobSucceeded = { imageUrl: string; status: "succeeded" }; export type ProdiaJob = | ProdiaJobQueued @@ -30,8 +25,9 @@ export type ProdiaGenerateRequest = { aspect_ratio?: "square" | "portrait" | "landscape"; }; -export type ProdiaTransformRequest = { - imageUrl: string; +type ImageInput = { imageUrl: string } | { imageData: string }; + +export type ProdiaTransformRequest = ImageInput & { prompt: string; model?: string; denoising_strength?: number; @@ -43,8 +39,7 @@ export type ProdiaTransformRequest = { sampler?: string; }; -export type ProdiaControlnetRequest = { - imageUrl: string; +export type ProdiaControlnetRequest = ImageInput & { controlnet_model: string; controlnet_module?: string; threshold_a?: number; @@ -61,6 +56,36 @@ export type ProdiaControlnetRequest = { height?: number; }; +type MaskInput = { maskUrl: string } | { maskData: string }; + +export type ProdiaInpaintingRequest = ImageInput & + MaskInput & { + prompt: string; + model?: string; + denoising_strength?: number; + negative_prompt?: string; + steps?: number; + cfg_scale?: number; + seed?: number; + upscale?: boolean; + mask_blur: number; + inpainting_fill: number; + inpainting_mask_invert: number; + inpainting_full_res: string; + sampler?: string; + }; + +export type ProdiaXlGenerateRequest = { + prompt: string; + model?: string; + negative_prompt?: string; + steps?: number; + cfg_scale?: number; + seed?: number; + upscale?: boolean; + sampler?: string; +}; + /* Constructor Definions */ export type Prodia = ReturnType; @@ -74,7 +99,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { const base = _base || "https://api.prodia.com/v1"; const headers = { - "X-Prodia-Key": apiKey, + "X-Prodia-Key": apiKey }; const generate = async (params: ProdiaGenerateRequest) => { @@ -82,9 +107,9 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { method: "POST", headers: { ...headers, - "Content-Type": "application/json", + "Content-Type": "application/json" }, - body: JSON.stringify(params), + body: JSON.stringify(params) }); if (response.status !== 200) { @@ -99,9 +124,9 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { method: "POST", headers: { ...headers, - "Content-Type": "application/json", + "Content-Type": "application/json" }, - body: JSON.stringify(params), + body: JSON.stringify(params) }); if (response.status !== 200) { @@ -116,9 +141,43 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { method: "POST", headers: { ...headers, - "Content-Type": "application/json", + "Content-Type": "application/json" + }, + body: JSON.stringify(params) + }); + + if (response.status !== 200) { + throw new Error(`Bad Prodia Response: ${response.status}`); + } + + return (await response.json()) as ProdiaJobQueued; + }; + + const inpainting = async (params: ProdiaInpaintingRequest) => { + const response = await fetch(`${base}/sd/inpainting`, { + method: "POST", + headers: { + ...headers, + "Content-Type": "application/json" + }, + body: JSON.stringify(params) + }); + + if (response.status !== 200) { + throw new Error(`Bad Prodia Response: ${response.status}`); + } + + return (await response.json()) as ProdiaJobQueued; + }; + + const xlGenerate = async (params: ProdiaXlGenerateRequest) => { + const response = await fetch(`${base}/sdxl/generate`, { + method: "POST", + headers: { + ...headers, + "Content-Type": "application/json" }, - body: JSON.stringify(params), + body: JSON.stringify(params) }); if (response.status !== 200) { @@ -130,7 +189,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { const getJob = async (jobId: string) => { const response = await fetch(`${base}/job/${jobId}`, { - headers, + headers }); if (response.status !== 200) { @@ -157,7 +216,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { const listModels = async () => { const response = await fetch(`${base}/models/list`, { - headers, + headers }); if (response.status !== 200) { @@ -171,8 +230,10 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => { generate, transform, controlnet, + inpainting, + xlGenerate, wait, getJob, - listModels, + listModels }; };