Does Metalhead.classify support pretrained models other than VGG19? #30
On my PC, I'm using Julia 1.0.1 with the following libraries:
I handle the data as follows:

```julia
using Flux: onecold
using Metalhead: ResNet, VGG19
using Images

# Take the len-by-len square of pixels at the center of image `im`
function center_crop(im, len)
    l2 = div(len, 2)
    adjust = len % 2 == 0 ? 1 : 0
    return im[div(end, 2)-l2:div(end, 2)+l2-adjust, div(end, 2)-l2:div(end, 2)+l2-adjust]
end

# Resize an image such that its smallest dimension is the given length
function resize_smallest_dimension(im, len)
    reduction_factor = len / minimum(size(im)[1:2])
    new_size = (
        round(Int, size(im, 1) * reduction_factor),
        round(Int, size(im, 2) * reduction_factor),
    )
    if reduction_factor < 1.0
        # Images.jl's imresize() needs to first lowpass the image; it won't do it for us
        im = imfilter(im, KernelFactors.gaussian(0.75 / reduction_factor), Inner())
    end
    return imresize(im, new_size)
end

function preprocess(im::AbstractMatrix{<:AbstractRGB})
    # Resize such that the smallest edge is 256 pixels long
    im = resize_smallest_dimension(im, 256)
    # Center-crop to 224x224
    im = center_crop(im, 224)
    # Convert to channel view and normalize (these coefficients are taken
    # from PyTorch's ImageNet normalization code)
    μ = [0.485, 0.456, 0.406]
    σ = [0.229, 0.224, 0.225]
    im = (channelview(im) .- μ) ./ σ
    # Convert from CHW (Images.jl's channel ordering) to WHCN (Flux.jl's ordering)
    # and convert the element type to Float64
    return Float64.(permutedims(im, (3, 2, 1))[:, :, :, :] .* 255)
end

# Installation-specific path: the "rGGAv" directory name differs between Metalhead installs
const label_txt = expanduser("~/.julia/packages/Metalhead/rGGAv/datasets/meta/ILSVRC_synset_mappings.txt")
imagenet_labels = String[]
for line in eachline(label_txt)
    # Each line holds a 9-character synset ID, a space, then the human-readable label
    synset = line[1:9]
    label = line[11:end]
    push!(imagenet_labels, label)
end

function main()
    model = ResNet()
    preprocessed = preprocess(load("elephant.jpeg"))
    # onecold returns the index of the highest score for each image in the batch
    result = imagenet_labels[onecold(model(preprocessed))][1]
    @show result
end

main()
```
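To make the layout conversion at the end of `preprocess` easier to follow, here is a minimal Base-Julia sketch (no Images.jl or Flux needed). It assumes a random 3×4×5 array as a hypothetical stand-in for `channelview(im)`, and shows that the per-channel normalization broadcasts over the first dimension and that `permutedims(..., (3, 2, 1))[:, :, :, :]` yields a WHCN array with a trailing batch dimension of 1. The `.* 255` scaling is left out.

```julia
# Hypothetical stand-in for channelview(im): a 3×H×W array in CHW order
# with values in [0, 1], like the RGB channels of an image.
C, H, W = 3, 4, 5
chw = rand(C, H, W)

# Per-channel ImageNet mean and std (same coefficients as in preprocess).
μ = [0.485, 0.456, 0.406]
σ = [0.229, 0.224, 0.225]

# A length-3 vector broadcasts along the first (channel) dimension,
# so channel c is shifted by μ[c] and scaled by σ[c].
normalized = (chw .- μ) ./ σ

# Reverse the dimension order CHW -> WHC, then index with one extra `:`
# to append a singleton batch dimension, giving WHCN for Flux.
whc = permutedims(normalized, (3, 2, 1))
whcn = whc[:, :, :, :]

@show size(chw) size(whcn)
```

Element `whcn[w, h, c, 1]` is `normalized[c, h, w]`, so the pixel data is unchanged; only the axis order and the extra batch axis differ.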
@terasakisatoshi There is an issue with the pretrained weights of ResNet and DenseNet that were imported using ONNX, so I am not surprised by the incorrect results.
Sorry to bother you so much today, but why doesn't the ResNet example load the batchnorm parameters?
IIRC, I discussed this with @ayush1999. The imported ONNX file had no batchnorm parameters, which is why ResNet doesn't perform as expected.
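A small Base-Julia sketch of why missing batchnorm parameters matter, assuming inference-mode batch normalization `y = γ (x - μ) / sqrt(σ² + ϵ) + β` per channel. The input values and trained statistics below are hypothetical. The point is that a layer left at the usual initial values (γ = 1, β = 0, running mean 0, running variance 1) passes activations through almost unchanged instead of normalizing them, so every following layer sees inputs on the wrong scale.

```julia
# Inference-mode batch normalization for a single channel:
#   y = γ * (x - μ) / sqrt(σ² + ϵ) + β
batchnorm_infer(x, γ, β, μ, σ²; ϵ = 1e-5) = γ .* (x .- μ) ./ sqrt.(σ² .+ ϵ) .+ β

x = [10.0, 12.0, 14.0]  # hypothetical pre-activation values for one channel

# Hypothetical trained parameters: γ = 2, β = 0.5, running mean 12, running variance 4.
trained = batchnorm_infer(x, 2.0, 0.5, 12.0, 4.0)

# If the imported file carried no batchnorm parameters and the layer kept
# its initial values, it is close to an identity map.
defaults = batchnorm_infer(x, 1.0, 0.0, 0.0, 1.0)

@show trained defaults
```

With the trained parameters the outputs are roughly `[-1.5, 0.5, 2.5]`, while the default layer returns roughly the raw inputs `[10, 12, 14]`.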
When I dump the .bson file, there are values such as these:
I believe `bn_b` is the batchnorm bias parameter.
Thanks. I will add them as soon as possible. :)
@ekinakyurek I have updated the ResNet model. You can check out the
I couldn't find the `newmodels` branch. Is it working?
Here's the branch.
Maybe something is wrong on my end, but ResNet50 is not giving correct results. A picture of an elephant, for example, is classified as "Persian cat", while VGG19 seems to work fine. Could there still be an issue with importing the pretrained weights? Also, could you coordinate merging everything, with all the fixes, into the default FluxML/Metalhead.jl? Very much appreciated.
I recently tried to import ResNet with ONNX.jl and failed, but then succeeded with ONNXmutable.jl.
The latest release no longer has the
I'm closing this issue for now so that #72 is the only issue tracking the incorrect results.
Hi, I'm trying to use the pretrained models stored in this repository: VGG, ResNet, GoogleNet, DenseNet, and so on.
I wrote the following code to see what happens.
This works fine, showing a kind of elephant as the result.
What about the other models, e.g. DenseNet?
ResNet and GoogleNet also show a similar error.
Does Metalhead.classify support pretrained models other than VGG19?