Skip to content

Commit

Permalink
prodia.ts: inpainting + sdxl
Browse files Browse the repository at this point in the history
  • Loading branch information
montyanderson committed Oct 10, 2023
1 parent 0cafc44 commit 4baca3f
Showing 1 changed file with 80 additions and 14 deletions.
94 changes: 80 additions & 14 deletions prodia.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ export type ProdiaGenerateRequest = {
aspect_ratio?: "square" | "portrait" | "landscape";
};

export type ProdiaTransformRequest = {
imageUrl: string;
type ImageInput = { imageUrl: string } | { imageData: string };

export type ProdiaTransformRequest = ImageInput & {
prompt: string;
model?: string;
denoising_strength?: number;
Expand All @@ -43,8 +44,7 @@ export type ProdiaTransformRequest = {
sampler?: string;
};

export type ProdiaControlnetRequest = {
imageUrl: string;
export type ProdiaControlnetRequest = ImageInput & {
controlnet_model: string;
controlnet_module?: string;
threshold_a?: number;
Expand All @@ -61,6 +61,36 @@ export type ProdiaControlnetRequest = {
height?: number;
};

type MaskInput = { maskUrl: string } | { maskData: string };

export type ProdiaInpaintingRequest = ImageInput &
MaskInput & {
prompt: string;
model?: string;
denoising_strength?: number;
negative_prompt?: string;
steps?: number;
cfg_scale?: number;
seed?: number;
upscale?: boolean;
mask_blur: number;
inpainting_fill: number;
inpainting_mask_invert: number;
inpainting_full_res: string;
sampler?: string;
};

export type ProdiaXlGenerateRequest = {
prompt: string;
model?: string;
negative_prompt?: string;
steps?: number;
cfg_scale?: number;
seed?: number;
upscale?: boolean;
sampler?: string;
};

/* Constructor Definions */

export type Prodia = ReturnType<typeof createProdia>;
Expand All @@ -74,17 +104,17 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
const base = _base || "https://api.prodia.com/v1";

const headers = {
"X-Prodia-Key": apiKey,
"X-Prodia-Key": apiKey
};

const generate = async (params: ProdiaGenerateRequest) => {
const response = await fetch(`${base}/sd/generate`, {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json",
"Content-Type": "application/json"
},
body: JSON.stringify(params),
body: JSON.stringify(params)
});

if (response.status !== 200) {
Expand All @@ -99,9 +129,9 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json",
"Content-Type": "application/json"
},
body: JSON.stringify(params),
body: JSON.stringify(params)
});

if (response.status !== 200) {
Expand All @@ -116,9 +146,43 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json",
"Content-Type": "application/json"
},
body: JSON.stringify(params)
});

if (response.status !== 200) {
throw new Error(`Bad Prodia Response: ${response.status}`);
}

return (await response.json()) as ProdiaJobQueued;
};

const inpainting = async (params: ProdiaInpaintingRequest) => {
const response = await fetch(`${base}/sd/inpainting`, {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json"
},
body: JSON.stringify(params)
});

if (response.status !== 200) {
throw new Error(`Bad Prodia Response: ${response.status}`);
}

return (await response.json()) as ProdiaJobQueued;
};

const xlGenerate = async (params: ProdiaXlGenerateRequest) => {
const response = await fetch(`${base}/sdxl/generate`, {
method: "POST",
headers: {
...headers,
"Content-Type": "application/json"
},
body: JSON.stringify(params),
body: JSON.stringify(params)
});

if (response.status !== 200) {
Expand All @@ -130,7 +194,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {

const getJob = async (jobId: string) => {
const response = await fetch(`${base}/job/${jobId}`, {
headers,
headers
});

if (response.status !== 200) {
Expand All @@ -157,7 +221,7 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {

const listModels = async () => {
const response = await fetch(`${base}/models/list`, {
headers,
headers
});

if (response.status !== 200) {
Expand All @@ -171,8 +235,10 @@ export const createProdia = ({ apiKey, base: _base }: CreateProdiaOptions) => {
generate,
transform,
controlnet,
inpainting,
xlGenerate,
wait,
getJob,
listModels,
listModels
};
};

0 comments on commit 4baca3f

Please sign in to comment.