-
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e6b77ee
commit 5cf3066
Showing
2 changed files
with
40 additions
and
173 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,174 +1,38 @@ | ||
#!/usr/bin/env python | ||
|
||
# Set to True to enable the optional image-inspection helpers below.
DEBUG = False

if DEBUG:
    # This code only exists to help us visually inspect the images.
    # It's in an `if DEBUG:` block to illustrate that we don't need it for our code to work.
    from PIL import Image
    import numpy as np

    def read_image(path):
        """Open the image at `path`, convert it to mode 'L', and return it as a NumPy array."""
        return np.asarray(Image.open(path).convert('L'))

    def write_image(image, path):
        """Build a mode-'L' PIL image from `image` and save it to `path`.

        `image` is passed through `np.array` first.
        NOTE(review): presumably a 2-D grid of single-byte pixels as produced
        by `read_images` -- confirm against callers.
        """
        img = Image.fromarray(np.array(image), 'L')
        img.save(path)
|
||
|
||
# Directory prefixes (each ends with a slash because paths are built by
# plain string concatenation).
DATA_DIR = 'data/'
TEST_DIR = 'test/'
DATASET = 'fashion-mnist'  # `'mnist'` or `'fashion-mnist'`

# IDX files for the selected dataset, all under DATA_DIR + DATASET.
TEST_DATA_FILENAME = DATA_DIR + DATASET + '/t10k-images-idx3-ubyte'
TEST_LABELS_FILENAME = DATA_DIR + DATASET + '/t10k-labels-idx1-ubyte'
TRAIN_DATA_FILENAME = DATA_DIR + DATASET + '/train-images-idx3-ubyte'
TRAIN_LABELS_FILENAME = DATA_DIR + DATASET + '/train-labels-idx1-ubyte'
|
||
|
||
def bytes_to_int(byte_data):
    """Interpret `byte_data` as an unsigned big-endian integer.

    An empty bytes object yields 0.
    """
    return int.from_bytes(byte_data, 'big')
|
||
|
||
def read_images(filename, n_max_images=None):
    """Read images from the binary file `filename`.

    The file is parsed as: a 4-byte magic number (ignored), then three 4-byte
    big-endian integers (image count, rows, columns), then one byte per pixel.

    Args:
        filename: Path of the image file to read.
        n_max_images: Optional cap on the number of images returned. A falsy
            value (``None`` or ``0``) means "read every image in the file".

    Returns:
        A list of images; each image is a list of rows and each row is a list
        of single-byte ``bytes`` objects (one per pixel).
    """
    images = []
    with open(filename, 'rb') as f:
        _ = f.read(4)  # magic number
        n_images = bytes_to_int(f.read(4))
        if n_max_images:
            # Never ask for more images than the header declares; reading past
            # the end of the file would fabricate images made of empty pixels.
            n_images = min(n_images, n_max_images)
        n_rows = bytes_to_int(f.read(4))
        n_columns = bytes_to_int(f.read(4))
        for _image_idx in range(n_images):
            image = []
            for _row_idx in range(n_rows):
                # One read per row instead of one per pixel; slicing keeps each
                # pixel a length-1 bytes object, as callers expect.
                row_bytes = f.read(n_columns)
                image.append(
                    [row_bytes[col:col + 1] for col in range(n_columns)]
                )
            images.append(image)
    return images
|
||
|
||
def read_labels(filename, n_max_labels=None):
    """Read labels from the binary file `filename`.

    The file is parsed as: a 4-byte magic number (ignored), a 4-byte
    big-endian label count, then one byte per label.

    Args:
        filename: Path of the label file to read.
        n_max_labels: Optional cap on the number of labels returned. A falsy
            value (``None`` or ``0``) means "read every label in the file".

    Returns:
        A list of integer labels.
    """
    with open(filename, 'rb') as f:
        _ = f.read(4)  # magic number
        n_labels = bytes_to_int(f.read(4))
        if n_max_labels:
            # Clamp to the header count so we never invent labels past EOF.
            n_labels = min(n_labels, n_max_labels)
        # Iterating a bytes object yields ints, one per label byte.
        return list(f.read(n_labels))
|
||
|
||
def flatten_list(l):
    """Flatten one level of nesting: concatenate the sublists of `l` into a new list."""
    return [pixel for sublist in l for pixel in sublist]
|
||
|
||
def extract_features(X):
    """Turn every sample in `X` into a flat feature vector via `flatten_list`.

    Returns a new list with one flattened list per sample.
    """
    return [flatten_list(sample) for sample in X]
|
||
|
||
def dist(x, y):
    """
    Returns the Euclidean distance between vectors `x` and `y`.

    Each element is converted with `bytes_to_int` before subtracting, so both
    vectors must hold bytes-like values. Elements are paired with `zip`, so
    any extra elements of the longer vector are ignored.
    """
    squared_sum = sum(
        (bytes_to_int(x_i) - bytes_to_int(y_i)) ** 2
        for x_i, y_i in zip(x, y)
    )
    return squared_sum ** 0.5
|
||
|
||
def get_training_distances_for_test_sample(X_train, test_sample):
    """Return `dist(train_sample, test_sample)` for every sample in `X_train`, in order."""
    return [dist(train_sample, test_sample) for train_sample in X_train]
|
||
|
||
def get_most_frequent_element(l):
    """Return the element of `l` that occurs most often.

    Ties go to the element that appears first in `l`. Raises ValueError when
    `l` is empty (from `max`).
    """
    return max(l, key=l.count)
|
||
|
||
def knn(X_train, y_train, X_test, k=3): | ||
y_pred = [] | ||
for test_sample_idx, test_sample in enumerate(X_test): | ||
print(test_sample_idx, end=' ', flush=True) | ||
training_distances = get_training_distances_for_test_sample( | ||
X_train, test_sample | ||
import fitz, os # PyMuPDF | ||
from .models import UploadedFile, Course | ||
from datetime import date | ||
from django.conf import settings | ||
|
||
from django.core.files.storage import FileSystemStorage | ||
|
||
""" | ||
OCR File Creation Pipeline: | ||
- When a file is uploaded to the filesystem, conduct OCR to extract and create text file | ||
- Check for name availability with created text file | ||
- Store text file on filesystem and database | ||
""" | ||
class OCR:
    """Text-extraction helpers for uploaded PDF files."""

    @staticmethod
    def extract_text_from_pdf(fType, course, fileName, fileUrl):
        """Extract a PDF's text, store it as a ``.txt`` file, and record it.

        Args:
            fType: Value stored in the ``flag`` field of the new record.
            course: Name of the ``Course`` the text file is attached to.
            fileName: Name of the uploaded PDF; its extension is replaced
                with ``.txt`` to name the text file.
            fileUrl: Location handed to ``fitz.open`` to load the PDF.

        Returns:
            The ``UploadedFile`` record that was created.

        Raises:
            Course.DoesNotExist: If no course is named `course`
                (raised by ``Course.objects.get``).
        """
        # Local import: ContentFile is only needed here.
        from django.core.files.base import ContentFile

        fs = FileSystemStorage()
        txt_file_name = os.path.splitext(fileName)[0] + ".txt"

        # Close the document deterministically and avoid quadratic `+=`.
        with fitz.open(fileUrl) as pdf_document:
            text = "".join(page.get_text() for page in pdf_document)

        # Let the storage backend pick a free name and write the file. The
        # previous code passed a path string as the *content* argument of
        # `fs.save` and used the PDF's name for the text file.
        saved_name = fs.save(txt_file_name, ContentFile(text.encode("utf-8")))
        txt_file_path = fs.path(saved_name)

        # Save the text content to the database using UploadedFile model.
        # `objects.create` already persists the row, so no extra `save()` is
        # required afterwards.
        uploaded_file = UploadedFile.objects.create(
            filename=saved_name,
            file_dir=txt_file_path,
            course=Course.objects.get(name=course),
            date_uploaded=date.today(),  # call it: the field needs a date, not the method
            flag=fType,
        )
        return uploaded_file
sorted_distance_indices = [ | ||
pair[0] | ||
for pair in sorted( | ||
enumerate(training_distances), | ||
key=lambda x: x[1] | ||
) | ||
] | ||
candidates = [ | ||
y_train[idx] | ||
for idx in sorted_distance_indices[:k] | ||
] | ||
top_candidate = get_most_frequent_element(candidates) | ||
y_pred.append(top_candidate) | ||
print() | ||
return y_pred | ||
|
||
|
||
def get_garment_from_label(label):
    """Return the garment name stored at index `label` (0-9).

    Plain sequence indexing is used, so an out-of-range label raises
    IndexError and a negative label counts from the end.
    """
    garments = (
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    )
    return garments[label]
|
||
|
||
def main():
    """Run the k-NN demo: load data, classify the test samples, print results.

    Reads `n_train` training and `n_test` test samples from the files named by
    the module-level `*_FILENAME` constants, predicts a label for each test
    sample with `knn`, and prints the predictions and the accuracy.
    """
    n_train = 1000
    n_test = 10
    k = 7
    print(f'Dataset: {DATASET}')
    print(f'n_train: {n_train}')
    print(f'n_test: {n_test}')
    print(f'k: {k}')
    X_train = read_images(TRAIN_DATA_FILENAME, n_train)
    y_train = read_labels(TRAIN_LABELS_FILENAME, n_train)
    X_test = read_images(TEST_DATA_FILENAME, n_test)
    y_test = read_labels(TEST_LABELS_FILENAME, n_test)

    if DEBUG:
        # Write some images out just so we can see them visually.
        for idx, test_sample in enumerate(X_test):
            write_image(test_sample, f'{TEST_DIR}{idx}.png')
        # Load in the `our_test.png` we drew ourselves!
        # X_test = [read_image(f'{DATA_DIR}our_test.png')]
        # y_test = [5]

    X_train = extract_features(X_train)
    X_test = extract_features(X_test)

    y_pred = knn(X_train, y_train, X_test, k)

    # Fraction of test samples whose predicted label matches the true label.
    n_correct = sum(
        int(y_pred_i == y_test_i)
        for y_pred_i, y_test_i in zip(y_pred, y_test)
    )
    accuracy = n_correct / len(y_test)

    if DATASET == 'fashion-mnist':
        garments_pred = [
            get_garment_from_label(label)
            for label in y_pred
        ]
        print(f'Predicted garments: {garments_pred}')
    else:
        print(f'Predicted labels: {y_pred}')

    print(f'Accuracy: {accuracy * 100}%')
uploaded_file.save() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters