Overview
I tried image classification with a machine-learning model in Flutter.
The rough steps are:
- Build the model with PyTorch
- Convert it to the tflite format
- Implement the inference in Flutter
Development Environment

OS | Windows 11 |
---|---|
Editor | VS Code |
Flutter Ver. | 3.0.5 |
Smartphone | Android |
Creating the Classification Model
Create an image-classification model with PyTorch and convert it to the tflite format.
All of the code is run on Google Colab.
Creating the classification model with PyTorch
Importing the libraries
```python
import os

import numpy as np
import torch
from torchvision import models
import torchvision.utils as vutils
from torchvision.io import read_image
```
Settings
```python
batch_size = 1
channel = 3
width = 224
height = 224

path_pth = "/content/model.pth"
path_onnx = "/content/model.onnx"
path_tf = "/content/tf_model"
path_tflite = "/content/model.tflite"
```
Saving the model
```python
# Load a pretrained MobileNetV3-Small and save its weights (state_dict).
net = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
print(net)
torch.save(net.state_dict(), path_pth)
```
Validating the model
```python
weights = models.MobileNet_V3_Small_Weights.DEFAULT
net.eval()
preprocess = weights.transforms()

img = read_image("/content/greyfox.jpg")
batch = preprocess(img).unsqueeze(0)
img_np = batch.to('cpu').detach().numpy().copy()

prediction = net(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
```
Result

```
grey fox: 97.51530885696411%
```
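For reference, the `preprocess` object bundled with the weights defines the resize/crop sizes and the ImageNet normalization statistics; the same preprocessing has to be reproduced later on the Flutter side. A small inspection snippet, not part of the original steps:

```python
# Inspect the preprocessing bundled with the pretrained weights:
# its repr lists the resize/crop sizes and the ImageNet mean/std.
print(preprocess)

# The preprocessed batch is already in NCHW layout.
print(batch.shape)  # torch.Size([1, 3, 224, 224])
```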
Creating the label data
```python
category_name = weights.meta["categories"]
print(category_name)

with open('labels.txt', 'w') as f:
    for x in category_name:
        f.write("%s\n" % x)
```
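As a quick sanity check (a small sketch, not part of the original steps), the written file should contain the 1,000 ImageNet class names; the Flutter-side classifier below assumes exactly this count.

```python
# labels.txt should hold the 1000 ImageNet class names,
# matching the _labelLength check in the Flutter Classifier below.
with open('labels.txt') as f:
    written_labels = [line for line in f.read().splitlines() if line]
print(len(written_labels))  # expected: 1000
```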
Converting the PyTorch model to TFLite
PyTorch ⇒ ONNX
```python
conver_model = net
conver_model.load_state_dict(torch.load(path_pth, map_location='cpu'))
conver_model.cpu()
conver_model.eval()

sample_input = torch.rand(batch_size, channel, width, height)

torch.onnx.export(
    conver_model,
    sample_input,
    path_onnx,
    opset_version=12,
    input_names=['input'],
    output_names=['output']
)
```
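Before moving on, it can be worth confirming that the exported ONNX graph reproduces the PyTorch output. A minimal sketch, assuming `onnxruntime` has been installed separately (`pip install onnxruntime`) and reusing `img_np` and `batch` from the validation step above:

```python
import numpy as np
import onnxruntime as ort

# Run the exported ONNX graph on the same preprocessed image.
sess = ort.InferenceSession(path_onnx)
onnx_out = sess.run(['output'], {'input': img_np})[0]

# Compare against the PyTorch logits (before softmax).
torch_out = net(batch).detach().numpy()
print(np.abs(onnx_out - torch_out).max())  # should be close to 0
```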
ONNX ⇒ TF
```
!git clone https://github.com/onnx/onnx-tensorflow.git
%cd onnx-tensorflow
!pip install -e .
```
```python
import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load(path_onnx)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(path_tf)
```
TF ⇒ TFLite
```python
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(path_tf)
tflite_model = converter.convert()

with open(path_tflite, 'wb') as f:
    f.write(tflite_model)
```
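If the exported `.tflite` file turns out to be too large for the app, the converter also supports dynamic-range quantization. This is not used in this article; the snippet below is just a sketch of the standard TensorFlow Lite option (the `model_quant.tflite` path is only an example name):

```python
import tensorflow as tf

# Optional: dynamic-range quantization shrinks the model by storing
# weights as 8-bit integers while keeping float activations.
converter = tf.lite.TFLiteConverter.from_saved_model(path_tf)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()

with open("/content/model_quant.tflite", "wb") as f:
    f.write(tflite_quant_model)
```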
Validating the tflite model
```python
import numpy as np
import tensorflow as tf

def softmax(x):
    f_x = np.exp(x) / np.sum(np.exp(x))
    return f_x

interpreter = tf.lite.Interpreter(model_path=path_tflite)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_data = img_np
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])
output_data = softmax(output_data).squeeze(0)
class_id = np.argmax(output_data)
score = output_data[class_id]
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")
```
Result

```
grey fox: 97.5152850151062%
```
Implementing Image Classification in Flutter
Installing and configuring the libraries
Installing the Flutter packages
```yaml
dependencies:
  image_picker: ^0.8.5+3
  tflite_flutter: ^0.9.0
  tflite_flutter_helper:
    git:
      url: https://github.com/elephantum/tflite_flutter_helper.git
      ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a
```
Downloading and installing TFLite
Download install.bat.
Place the downloaded file in the project root.
After placing it, run the install:

```
install.bat
```
Configuring the assets
```yaml
assets:
  - assets/model.tflite
  - assets/labels.txt
```
Implementation steps
Creating the main screen
```dart
import 'dart:io';

import 'package:flutter_tflite_cassification/classifier.dart';
import 'package:flutter_tflite_cassification/classifier_float.dart';
import 'package:image/image.dart' as img;
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

void main() {
  runApp(const MyApp());
}

class MyApp extends StatelessWidget {
  const MyApp({Key? key}) : super(key: key);

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
        useMaterial3: true,
      ),
      home: const MyHomePage(title: 'Flutter Classification'),
    );
  }
}

class MyHomePage extends StatefulWidget {
  const MyHomePage({Key? key, required this.title}) : super(key: key);

  final String title;

  @override
  State<MyHomePage> createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  final picker = ImagePicker();
  String? _imagePath;
  File? _image;
  Image? _imageWidget;
  late Classifier _classifier;
  Category? category;

  @override
  void initState() {
    super.initState();
    _classifier = Classifier();
  }

  // Pick an image from the gallery and run the prediction on it.
  Future _getImage() async {
    final pickedFile = await picker.pickImage(source: ImageSource.gallery);
    if (pickedFile != null) {
      setState(() {
        _imagePath = pickedFile.path;
        _image = File(_imagePath!);
        _imageWidget = Image.file(_image!);
        _predict();
      });
    }
  }

  // Decode the picked file and pass it to the classifier.
  void _predict() async {
    img.Image imageInput = img.decodeImage(_image!.readAsBytesSync())!;
    var pred = _classifier.predict(imageInput);
    setState(() {
      category = pred;
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text(widget.title),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: <Widget>[
            Center(
              child: _image == null
                  ? const Text("no image")
                  : Container(
                      child: _imageWidget,
                    ),
            ),
            const SizedBox(height: 36),
            Text(
              category != null ? category!.label : "",
              style: const TextStyle(fontSize: 20, fontWeight: FontWeight.w600),
            ),
            const SizedBox(height: 8),
            Text(
              category != null ? 'Score: ${category!.score.toStringAsFixed(3)}' : '',
              style: const TextStyle(fontSize: 16),
            ),
          ],
        ),
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: _getImage,
        tooltip: 'Pick Image',
        child: const Icon(Icons.add_a_photo),
      ),
    );
  }
}
```
Creating the classifier class
```dart
import 'dart:math';

import 'package:image/image.dart';
import 'package:collection/collection.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

class Classifier {
  late Interpreter interpreter;
  late InterpreterOptions _interpreterOptions;

  late List<int> _inputShape;
  late List<int> _outputShape;

  late TensorImage _inputImage;
  late TensorBuffer _outputBuffer;

  late TfLiteType _inputType;
  late TfLiteType _outputType;

  late var _probabilityProcessor;

  final String modelName = 'model.tflite';
  final String _labelFileName = 'assets/labels.txt';
  final int _labelLength = 1000;

  final NormalizeOp preProcessNormalizeOp = NormalizeOp(114.495, 57.63);
  final NormalizeOp postProcessNormalizeOp = NormalizeOp(0, 1);

  late List<String> labels;

  Classifier({int? numThreads}) {
    _interpreterOptions = InterpreterOptions();
    if (numThreads != null) {
      _interpreterOptions.threads = numThreads;
    }
    loadModel();
    loadLabels();
  }

  // Load the model
  Future<void> loadModel() async {
    try {
      interpreter =
          await Interpreter.fromAsset(modelName, options: _interpreterOptions);
      print('Interpreter Created Successfully');

      _inputShape = interpreter.getInputTensor(0).shape;
      _outputShape = interpreter.getOutputTensor(0).shape;
      _inputType = interpreter.getInputTensor(0).type;
      _outputType = interpreter.getOutputTensor(0).type;

      _outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);
      _probabilityProcessor =
          TensorProcessorBuilder().add(postProcessNormalizeOp).build();
    } catch (e) {
      print('Unable to create interpreter, Caught Exception: ${e.toString()}');
    }
  }

  // Load the labels
  Future<void> loadLabels() async {
    labels = await FileUtil.loadLabels(_labelFileName);
    if (labels.length == _labelLength) {
      print('Labels loaded successfully');
    } else {
      print('Unable to load labels');
    }
  }

  // Preprocess the image
  TensorImage _preProcess() {
    int cropSize = min(_inputImage.height, _inputImage.width);
    return ImageProcessorBuilder()
        .add(ResizeWithCropOrPadOp(cropSize, cropSize))
        .add(ResizeOp(_inputShape[2], _inputShape[3], ResizeMethod.BILINEAR))
        .add(preProcessNormalizeOp)
        .build()
        .process(_inputImage);
  }

  // Run inference
  Category predict(Image image) {
    print("---start predict---");

    // Preprocess the input image
    _inputImage = TensorImage(_inputType);
    _inputImage.loadImage(image);
    _inputImage = _preProcess();

    int ch = _inputShape[1];
    int w = _inputShape[2];
    int h = _inputShape[3];
    List inputImage = List.filled(ch * w * h, 0.0).reshape([1, ch, w, h]);
    List inputImageList = _inputImage
        .getTensorBuffer()
        .getBuffer()
        .asFloat32List()
        .reshape([1, w, h, ch]);

    // (1,224,224,3) => (1,3,224,224)
    for (int c = 0; c < ch; c++) {
      for (int x = 0; x < w; x++) {
        for (int y = 0; y < h; y++) {
          inputImage[0][c][x][y] = inputImageList[0][x][y][c];
        }
      }
    }

    // Run the interpreter
    interpreter.run(inputImage, _outputBuffer.getBuffer());

    Map<String, double> labeledProb = TensorLabel.fromList(
            labels, _probabilityProcessor.process(_outputBuffer))
        .getMapWithFloatValue();
    labeledProb = softmax(labeledProb);
    final pred = getTopProbability(labeledProb);

    return Category(pred.key, pred.value);
  }

  Map<String, double> softmax(Map<String, double> labeledProb) {
    Map<String, double> ret;
    ret = labeledProb;
    // Sum of exp over all logits (every element exponentiated).
    var sum = labeledProb.values.fold(0.0, (a, b) => a + exp(b));
    labeledProb.forEach((key, value) {
      ret[key] = exp(value) / sum;
    });
    return ret;
  }

  MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
    var pq = PriorityQueue<MapEntry<String, double>>(compare);
    pq.addAll(labeledProb.entries);
    return pq.first;
  }

  int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
    if (e1.value > e2.value) {
      return -1;
    } else if (e1.value == e2.value) {
      return 0;
    } else {
      return 1;
    }
  }
}
```
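Two details of the `Classifier` tie back to the Python side. The `NormalizeOp(114.495, 57.63)` values correspond to the ImageNet mean and std averaged over the three channels and scaled from the 0–1 range to 0–255, and the triple loop in `predict()` performs the same NHWC→NCHW reordering that numpy expresses as a transpose. A small numpy sketch, for illustration only (not part of the app):

```python
import numpy as np

# Where NormalizeOp(114.495, 57.63) comes from: the ImageNet statistics
# averaged over the three channels and scaled from [0, 1] to [0, 255].
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
print(imagenet_mean.mean() * 255)  # ≈ 114.495
print(imagenet_std.mean() * 255)   # ≈ 57.63

# What the triple loop in predict() does: the NHWC buffer (1, 224, 224, 3)
# is reordered into the NCHW layout (1, 3, 224, 224) the exported model expects.
nhwc = np.random.rand(1, 224, 224, 3).astype(np.float32)
nchw = nhwc.transpose(0, 3, 1, 2)
print(nchw.shape)  # (1, 3, 224, 224)
```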
Results
References
- FlutterとTensorFlow Lite(TFLite)でリアルタイム物体検出をしてみる｜うぇるち｜note (Real-time object detection with Flutter and TensorFlow Lite)
- tflite_flutter_helper/example/image_classification at master · am15h/tflite_flutter_helper