From 92d6f99dcffa7440dd5c6e214eb371539b8d81d6 Mon Sep 17 00:00:00 2001 From: Abdelaziz Mahdy Date: Sun, 27 Oct 2024 11:23:37 -0300 Subject: [PATCH] adding image picker camera screen to examples --- example/lib/main.dart | 25 ++- ...run_model_by_image_picker_camera_demo.dart | 149 ++++++++++++++++++ example/pubspec.lock | 2 +- 3 files changed, 174 insertions(+), 2 deletions(-) create mode 100644 example/lib/run_model_by_image_picker_camera_demo.dart diff --git a/example/lib/main.dart b/example/lib/main.dart index 24a7a27..564aa7b 100644 --- a/example/lib/main.dart +++ b/example/lib/main.dart @@ -1,6 +1,7 @@ import 'package:flutter/material.dart'; import 'package:pytorch_lite_example/run_model_by_camera_demo.dart'; import 'package:pytorch_lite_example/run_model_by_image_demo.dart'; +import 'package:pytorch_lite_example/run_model_by_image_picker_camera_demo.dart'; Future main() async { runApp(const ChooseDemo()); @@ -60,7 +61,27 @@ class _ChooseDemoState extends State { color: Colors.white, ), ), - ) + ), + TextButton( + onPressed: () => { + Navigator.push( + context, + MaterialPageRoute( + builder: (context) => + const RunModelByImagePickerCameraDemo(), + ), + ) + }, + style: TextButton.styleFrom( + backgroundColor: Colors.blue, + ), + child: const Text( + "Run Model with ImagePicker Camera", + style: TextStyle( + color: Colors.white, + ), + ), + ), ], ), ); @@ -69,3 +90,5 @@ class _ChooseDemoState extends State { ); } } + + diff --git a/example/lib/run_model_by_image_picker_camera_demo.dart b/example/lib/run_model_by_image_picker_camera_demo.dart new file mode 100644 index 0000000..52df5ec --- /dev/null +++ b/example/lib/run_model_by_image_picker_camera_demo.dart @@ -0,0 +1,149 @@ +import 'dart:io'; +import 'dart:typed_data'; +import 'package:flutter/material.dart'; +import 'package:image_picker/image_picker.dart'; +import 'package:pytorch_lite/pytorch_lite.dart'; +import 'package:pytorch_lite_example/ui/box_widget.dart'; + +class RunModelByImagePickerCameraDemo extends StatefulWidget { + const RunModelByImagePickerCameraDemo({Key? key}) : super(key: key); + + @override + _RunModelByImagePickerCameraDemoState createState() => + _RunModelByImagePickerCameraDemoState(); +} + +class _RunModelByImagePickerCameraDemoState + extends State { + List? objectDetectionResults; + String? classificationResult; + Duration? objectDetectionInferenceTime; + Duration? classificationInferenceTime; + File? _image; + ModelObjectDetection? _objectModel; + ClassificationModel? _imageModel; + bool _isLoading = false; // Add loading state + + @override + void initState() { + super.initState(); + loadModel(); + } + + Future loadModel() async { + String pathImageModel = "assets/models/model_classification.pt"; + String pathObjectDetectionModel = "assets/models/yolov5s.torchscript"; + try { + _imageModel = await PytorchLite.loadClassificationModel( + pathImageModel, 224, 224, 1000, // Adjust as needed + labelPath: "assets/labels/label_classification_imageNet.txt", + ); + _objectModel = await PytorchLite.loadObjectDetectionModel( + pathObjectDetectionModel, + 80, + 640, + 640, + labelPath: "assets/labels/labels_objectDetection_Coco.txt", + ); + } catch (e) { + print("Error loading model: $e"); + } + } + + Future runModels() async { + setState(() => _isLoading = true); + + final ImagePicker picker = ImagePicker(); + final XFile? pickedImage = + await picker.pickImage(source: ImageSource.camera); + if (pickedImage == null) { + setState(() => _isLoading = false); + return; + } + + File image = File(pickedImage.path); + Uint8List imageBytes = await image.readAsBytes(); // Read bytes once + + // Run both models concurrently + final results = await Future.wait([ + () async { + Stopwatch stopwatch = Stopwatch()..start(); + try { + return await _imageModel?.getImagePrediction(imageBytes); + } catch (e) { + print("Error during classification: $e"); + return null; // or handle the error as needed + } finally { + classificationInferenceTime = stopwatch.elapsed; + } + }(), + () async { + Stopwatch stopwatch = Stopwatch()..start(); + try { + return await _objectModel?.getImagePrediction( + imageBytes, + minimumScore: 0.1, + iOUThreshold: 0.3, + ); + } catch (e) { + print("Error during object detection: $e"); + return null; // or handle the error as needed + } finally { + objectDetectionInferenceTime = stopwatch.elapsed; + } + }(), + ]); + + classificationResult = results[0] as String?; + objectDetectionResults = results[1] as List?; + + setState(() { + _image = image; + _isLoading = false; + }); + } + + @override + Widget build(BuildContext context) { + return Scaffold( + appBar: AppBar(title: const Text('Run Models')), + body: Center( + child: _isLoading + ? const CircularProgressIndicator() // Show loading indicator + : Column( + mainAxisAlignment: MainAxisAlignment.center, + children: [ + if (_image != null) ...[ + SizedBox( + height: MediaQuery.sizeOf(context).height * 0.5, + child: Padding( + padding: const EdgeInsets.all(20), + child: _objectModel!.renderBoxesOnImage( + _image!, objectDetectionResults ?? []), + ), + ), + const SizedBox(height: 20), + Text( + "Classification Result: ${classificationResult ?? "N/A"}", + style: const TextStyle(fontSize: 16), + ), + Text( + "Classification Time: ${classificationInferenceTime?.inMilliseconds ?? "N/A"} ms", + style: const TextStyle(fontSize: 16), + ), + Text( + "Object Detection Time: ${objectDetectionInferenceTime?.inMilliseconds ?? "N/A"} ms", + style: const TextStyle(fontSize: 16), + ), + const SizedBox(height: 20), + ], + ElevatedButton( + onPressed: runModels, + child: const Text('Take Photo & Run Models'), + ), + ], + ), + ), + ); + } +} diff --git a/example/pubspec.lock b/example/pubspec.lock index ba61e8c..ac671ad 100644 --- a/example/pubspec.lock +++ b/example/pubspec.lock @@ -469,7 +469,7 @@ packages: path: ".." relative: true source: path - version: "4.2.7" + version: "4.3.0" sky_engine: dependency: transitive description: flutter