From 3b7993024611db63d9cd05aa442df21e45e6fe5e Mon Sep 17 00:00:00 2001
From: johnathanchiu
Date: Fri, 27 Sep 2024 15:55:46 -0400
Subject: [PATCH] [WIP] start setting up model training code

---
 model/seg/data.py | 78 +++++++++++++++++++++++++++++++++++++++++++++++
 model/seg/rseg.py | 21 +++++++++++++
 model/train.py    | 11 +++++++
 3 files changed, 110 insertions(+)
 create mode 100644 model/seg/data.py
 create mode 100644 model/seg/rseg.py
 create mode 100644 model/train.py

diff --git a/model/seg/data.py b/model/seg/data.py
new file mode 100644
index 0000000..e3c739e
--- /dev/null
+++ b/model/seg/data.py
@@ -0,0 +1,78 @@
+from torch.utils.data import Dataset
+from torchvision import transforms
+from PIL import Image
+import os
+import numpy as np
+
+# Define a transform to preprocess the images
+transform = transforms.Compose(
+    [
+        transforms.RandomCrop(128),  # Randomly crop images to 128x128
+        transforms.ToTensor(),  # Convert images to tensor
+    ]
+)
+
+
+# Create a custom dataset class
+class DocumentSegmentationDataset(Dataset):
+    def __init__(self, root_dir, bbox_data, transform=None):
+        self.root_dir = root_dir
+        self.bbox_data = bbox_data  # Dictionary mapping image names to bounding boxes
+        self.transform = transform
+        self.image_files = os.listdir(root_dir)  # List all image files in the directory
+
+    def __len__(self):
+        return len(self.image_files)  # Return the total number of images
+
+    def __getitem__(self, idx):
+        img_name = os.path.join(
+            self.root_dir, self.image_files[idx]
+        )  # Get image file path
+        image = Image.open(img_name)  # Open the image
+
+        if self.transform:
+            image = self.transform(image)  # Apply transformations if any
+
+        # Assuming the model expects strips of the image
+        strips, labels = self.create_strips_and_labels(
+            image, img_name
+        )  # Create strips and labels
+
+        return strips, labels  # Return the strips and their corresponding labels
+
+    def create_strips_and_labels(self, image, img_name):
+        # Convert image to numpy array for processing
+        image_array = np.array(image)
+        height, width = image_array.shape[:2]
+        strip_height = 32  # Define the height of each strip
+        strips = []
+        labels = []
+
+        # Get bounding boxes for the current image
+        bboxes = self.bbox_data.get(
+            os.path.basename(img_name), []
+        )  # Get bounding boxes for the image
+
+        # Create strips from the image
+        for y in range(0, height, strip_height):
+            strip = image_array[y : y + strip_height, :]  # Get a strip
+            strips.append(strip)
+
+            # Check if any bounding box intersects with the current strip
+            label = self.check_intersection(bboxes, y, strip_height)
+            labels.append(label)
+
+        return (
+            strips,
+            labels,
+        )  # Return the list of strips and their corresponding labels
+
+    def check_intersection(self, bboxes, y, strip_height):
+        # Check if any bounding box intersects with the strip
+        for bbox in bboxes:
+            # bbox format: [x_min, y_min, x_max, y_max]
+            if (
+                bbox[1] < y + strip_height and bbox[3] > y
+            ):  # Check for vertical intersection
+                return 1  # There is a page break
+        return 0  # No page break
diff --git a/model/seg/rseg.py b/model/seg/rseg.py
new file mode 100644
index 0000000..8437bcc
--- /dev/null
+++ b/model/seg/rseg.py
@@ -0,0 +1,21 @@
+import pytorch_lightning as pl
+import torch.nn as nn
+import torch
+
+
+class SegmentationModel(pl.LightningModule):
+    def __init__(self):
+        super(SegmentationModel, self).__init__()
+        self.layer = nn.Linear(10, 1)  # Example layer
+
+    def forward(self, x):
+        return self.layer(x)
+
+    def training_step(self, batch):
+        x, y = batch
+        y_hat = self(x)
+        loss = nn.functional.mse_loss(y_hat, y)
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.001)
diff --git a/model/train.py b/model/train.py
new file mode 100644
index 0000000..ddba4c9
--- /dev/null
+++ b/model/train.py
@@ -0,0 +1,11 @@
+import torch
+
+
+# Sample data
+x = torch.randn(100, 10)
+y = torch.randn(100, 1)
+
+# Training
+model = SimpleModel()
+trainer = pl.Trainer(max_epochs=5)
+trainer.fit(model, train_loader)
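
Note on model/train.py: as committed, the script references SimpleModel, pl, and train_loader without importing or defining them, so the WIP training entry point does not run yet. Below is a minimal sketch of the missing wiring, assuming the intent is to train the SegmentationModel from model/seg/rseg.py on the synthetic tensors above; the import path, TensorDataset wrapping, and batch size are placeholders, not part of the patch.

import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

# Assumed import path; adjust to the actual package layout of model/seg/rseg.py
from seg.rseg import SegmentationModel

# Sample data, matching the shapes used in the patch (10 input features, 1 target)
x = torch.randn(100, 10)
y = torch.randn(100, 1)
train_loader = DataLoader(TensorDataset(x, y), batch_size=16)  # batch size is a placeholder

# Training
model = SegmentationModel()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader)

The DocumentSegmentationDataset from model/seg/data.py is not wired in yet; a later revision would presumably replace the synthetic tensors with a DataLoader over that dataset.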