Commit

[WIP] start setting up model training code
johnathanchiu committed Sep 27, 2024
1 parent 5404797 commit 3b79930
Showing 3 changed files with 110 additions and 0 deletions.
78 changes: 78 additions & 0 deletions model/seg/data.py
@@ -0,0 +1,78 @@
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
import torch
import os

# Default transform to preprocess the images.
# Note: spatial transforms such as RandomCrop move pixel coordinates, so the
# bounding boxes would need the same crop applied to stay valid.
transform = transforms.Compose(
    [
        transforms.RandomCrop(128),  # Randomly crop images to 128x128
        transforms.ToTensor(),  # Convert images to a CxHxW float tensor
    ]
)


# Custom dataset: splits each page image into horizontal strips and labels
# each strip 1 if a page-break bounding box intersects it, else 0.
class DocumentSegmentationDataset(Dataset):
    def __init__(self, root_dir, bbox_data, transform=None):
        self.root_dir = root_dir
        self.bbox_data = bbox_data  # Maps image file name -> list of [x_min, y_min, x_max, y_max]
        self.transform = transform
        self.image_files = sorted(os.listdir(root_dir))  # All image files in the directory

    def __len__(self):
        return len(self.image_files)  # Total number of images

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.image_files[idx])
        image = Image.open(img_name).convert("RGB")

        if self.transform:
            image = self.transform(image)

        # Split the (possibly transformed) image into fixed-height strips and
        # label each strip by whether a bounding box intersects it.
        strips, labels = self.create_strips_and_labels(image, img_name)

        return strips, labels

    def create_strips_and_labels(self, image, img_name):
        # Convert to an HxWxC numpy array whether the transform already
        # produced a CxHxW tensor or the image is still a PIL image.
        if isinstance(image, torch.Tensor):
            image_array = image.permute(1, 2, 0).numpy()
        else:
            image_array = np.array(image)

        height, width = image_array.shape[:2]
        strip_height = 32  # Height of each horizontal strip
        strips = []
        labels = []

        # Bounding boxes for the current image (empty list if none recorded)
        bboxes = self.bbox_data.get(os.path.basename(img_name), [])

        # Walk down the image in strip_height steps; the last strip may be shorter.
        for y in range(0, height, strip_height):
            strip = image_array[y : y + strip_height, :]
            strips.append(strip)

            # 1 if any bounding box overlaps this strip vertically, else 0
            labels.append(self.check_intersection(bboxes, y, strip_height))

        return strips, labels

    def check_intersection(self, bboxes, y, strip_height):
        # bbox format: [x_min, y_min, x_max, y_max]
        for bbox in bboxes:
            if bbox[1] < y + strip_height and bbox[3] > y:  # vertical overlap
                return 1  # A page break falls inside this strip
        return 0  # No page break in this strip
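
Since each page yields a different number of strips, the default collate function cannot stack items from this dataset. Below is a minimal usage sketch (not part of the commit); the data/pages directory and the contents of page_bboxes are illustrative placeholders:

# Usage sketch: directory name and page_bboxes values below are hypothetical.
from torch.utils.data import DataLoader

page_bboxes = {"page_001.png": [[12, 40, 580, 72]]}  # file name -> list of [x_min, y_min, x_max, y_max]
dataset = DocumentSegmentationDataset("data/pages", page_bboxes, transform=transform)

# Each item is (strips, labels) with a per-page length, so keep items as-is
# rather than stacking them into a single tensor.
loader = DataLoader(dataset, batch_size=1, collate_fn=lambda batch: batch)

strips, labels = dataset[0]
print(len(strips), labels[:5])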
21 changes: 21 additions & 0 deletions model/seg/rseg.py
@@ -0,0 +1,21 @@
import pytorch_lightning as pl
import torch.nn as nn
import torch


class SegmentationModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)  # Placeholder layer until the real segmentation architecture lands

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
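
A quick shape check of the placeholder model (not part of the commit), assuming the package is importable from the repository root:

# Sanity-check sketch: exercises the placeholder Linear(10, 1) model.
import torch
from model.seg.rseg import SegmentationModel  # assumes the repo root is on the Python path

model = SegmentationModel()
x = torch.randn(4, 10)  # batch of 4 ten-dimensional feature vectors
y = torch.randn(4, 1)   # matching regression targets
print(model(x).shape)                  # torch.Size([4, 1])
print(model.training_step((x, y), 0))  # scalar MSE loss tensor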
11 changes: 11 additions & 0 deletions model/train.py
@@ -0,0 +1,11 @@
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, TensorDataset

from model.seg.rseg import SegmentationModel  # placeholder model; assumes `python -m model.train` from the repo root

# Sample data matching the placeholder Linear(10, 1) model
x = torch.randn(100, 10)
y = torch.randn(100, 1)
train_loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

# Training
model = SegmentationModel()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(model, train_loader)
