[Flutter] How to Perform Image Classification on Android

Overview

I tried running image classification with a machine learning model in Flutter.

The overall steps are:

  1. Build the model in PyTorch
  2. Convert it to TFLite format
  3. Implement the app in Flutter

Development environment

OS: Windows 11
Editor: VS Code
Flutter: 3.0.5
Smartphone: Android

Creating the classification model

Build an image classification model in PyTorch and convert it to TFLite format.

All of the code below was run on Google Colab.

Building a classification model in PyTorch

Importing libraries

import os
import numpy as np
import torch
from torchvision import models
import torchvision.utils as vutils
from torchvision.io import read_image

Settings

batch_size = 1
channel = 3
width = 224
height = 224

path_pth = "/content/model.pth"
path_onnx = "/content/model.onnx"
path_tf = "/content/tf_model"
path_tflite = "/content/model.tflite"

Loading and saving the model

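# Pass a weights enum member (not the enum class itself) to get the pretrained ImageNet weights
weights = models.MobileNet_V3_Small_Weights.DEFAULT
net = models.mobilenet_v3_small(weights=weights)
print(net)
torch.save(net.state_dict(), path_pth)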

Verifying the model

weights = models.MobileNet_V3_Small_Weights.DEFAULT
net.eval()

preprocess = weights.transforms()

img = read_image("/content/greyfox.jpg")
batch = preprocess(img).unsqueeze(0)
img_np = batch.to('cpu').detach().numpy().copy()

with torch.no_grad():  # inference only, so gradient tracking is unnecessary
    prediction = net(batch).squeeze(0).softmax(0)
class_id = prediction.argmax().item()
score = prediction[class_id].item()
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

Result

grey fox: 97.51530885696411%
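For reference, the post-processing in the check above (softmax over the class logits, then argmax) can be sketched in plain Python. This is an illustrative re-implementation, not torchvision's code, and the three labels and logit values are made up. Because softmax is monotonic, the argmax is the same before and after softmax; softmax only matters for the displayed score. The same computation is needed again on the TFLite and Flutter sides.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of raw scores (logits)."""
    # Subtracting the maximum leaves the result unchanged but prevents overflow in exp()
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top1(logits, labels):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    class_id = max(range(len(probs)), key=probs.__getitem__)
    return labels[class_id], probs[class_id]


# Hypothetical 3-class example (not real MobileNetV3 output)
labels = ["red fox", "grey fox", "kit fox"]
logits = [1.0, 5.0, 2.0]
label, prob = top1(logits, labels)
print(f"{label}: {100 * prob}%")
```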

Creating the label file

category_name = weights.meta["categories"]
print(category_name)

with open('labels.txt', 'w') as f:
  for x in category_name:
    f.write("%s\n" % x)
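Line i of labels.txt has to correspond to output index i of the model, because the Flutter side pairs labels and scores by position (`TensorLabel.fromList(labels, ...)`) and expects exactly 1000 entries. A small round-trip sketch in plain Python that writes such a file and reads it back; the three category names are only a sample, and the explicit UTF-8 encoding is a choice made here, not something the original code specifies.

```python
import os
import tempfile


def write_labels(path, categories):
    """Write one category name per line, as the labels.txt step above does."""
    with open(path, "w", encoding="utf-8") as f:
        for name in categories:
            f.write(f"{name}\n")


def read_labels(path):
    """Read labels back; line i corresponds to output index i of the model."""
    with open(path, encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


# Sample subset for illustration (the real file holds all 1000 ImageNet names)
categories = ["tench", "goldfish", "great white shark"]
with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "labels.txt")
    write_labels(path, categories)
    loaded = read_labels(path)

print(len(loaded), loaded[1])
```

The trailing newline after the last name does not produce an extra empty entry when read back this way, so the count stays equal to the number of categories.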

Converting the PyTorch model to TFLite format

PyTorch ⇒ ONNX

convert_model = net
convert_model.load_state_dict(torch.load(path_pth, map_location='cpu'))
convert_model.cpu()
convert_model.eval()

# Dummy input in NCHW order: (batch, channel, height, width)
sample_input = torch.rand(batch_size, channel, height, width)

torch.onnx.export(
    convert_model,
    sample_input,
    path_onnx,
    opset_version=12,
    input_names=['input'],
    output_names=['output']
)
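The dummy input above is in NCHW order (batch, channel, height, width), and the model converted through onnx-tf keeps that input layout; this is why the Flutter code later reads `_inputShape[1]` as the channel count. Image buffers such as the one behind tflite_flutter_helper's `TensorImage` are instead interleaved per pixel (NHWC), so the pixels must be reordered before inference. A minimal pure-Python sketch of that reordering on a flat buffer with batch size 1, using a made-up 2×2 RGB image:

```python
def nhwc_to_nchw(flat, height, width, channels):
    """Reorder a flat NHWC pixel buffer (batch size 1) into a flat NCHW buffer.

    NHWC index: (y * width + x) * channels + c
    NCHW index: c * height * width + y * width + x
    """
    assert len(flat) == height * width * channels
    out = [0.0] * len(flat)
    for c in range(channels):
        for y in range(height):
            for x in range(width):
                out[c * height * width + y * width + x] = flat[(y * width + x) * channels + c]
    return out


# 2x2 image with 3 channels; each pixel is stored as [R, G, B]
nhwc = [
    1, 2, 3,    4, 5, 6,     # row 0: pixel (0,0), pixel (0,1)
    7, 8, 9,    10, 11, 12,  # row 1: pixel (1,0), pixel (1,1)
]
nchw = nhwc_to_nchw(nhwc, height=2, width=2, channels=3)
print(nchw)  # [1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12]
```

All red values come first, then all green, then all blue, which is the order the converted model expects.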

ONNX ⇒ TF

!git clone https://github.com/onnx/onnx-tensorflow.git
%cd onnx-tensorflow
!pip install -e .

import onnx
from onnx_tf.backend import prepare

onnx_model = onnx.load(path_onnx)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(path_tf)

TF ⇒ tflite

import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model(path_tf)
tflite_model = converter.convert()

with open(path_tflite, 'wb') as f:
    f.write(tflite_model)

Verifying the TFLite model

import numpy as np
import tensorflow as tf

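def softmax(x):
    # Subtract the max for numerical stability and normalize along the class axis
    e_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e_x / np.sum(e_x, axis=-1, keepdims=True)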

interpreter = tf.lite.Interpreter(model_path=path_tflite)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_data = img_np
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

output_data = interpreter.get_tensor(output_details[0]['index'])

output_data = softmax(output_data).squeeze(0)
class_id = np.argmax(output_data)
score = output_data[class_id]
category_name = weights.meta["categories"][class_id]
print(f"{category_name}: {100 * score}%")

Result

grey fox: 97.5152850151062%
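The TFLite score agrees with the PyTorch score to four decimal places (97.51530885696411% vs 97.5152850151062%), a difference of roughly 2.4e-5 percentage points, which is consistent with ordinary floating-point differences introduced during conversion. One way to make this check explicit is a max-absolute-difference test. The sketch below applies it to the two reported top-1 scores; the same function works on full 1000-element probability vectors. The 1e-3 tolerance is an arbitrary assumption.

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length score vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    return max(abs(x - y) for x, y in zip(a, b))


# Top-1 scores (in %) reported above for the same image
pytorch_score = 97.51530885696411
tflite_score = 97.5152850151062

diff = max_abs_diff([pytorch_score], [tflite_score])
print(f"difference: {diff:.6f} percentage points")
# Hypothetical tolerance; tighten or loosen it to suit your model
assert diff < 1e-3, "conversion changed the output more than expected"
```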

Implementing image classification in Flutter

Installing and configuring libraries

Installing the Flutter packages (pubspec.yaml)

dependencies:

  image_picker: ^0.8.5+3
  tflite_flutter: ^0.9.0
  tflite_flutter_helper:
    git:
      url: https://github.com/elephantum/tflite_flutter_helper.git
      ref: a7d7a59a33f7cffa0a2a12ab05625807622cc97a

Downloading and installing TFLite

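Download install.bat from the tflite_flutter repository. This script fetches the prebuilt TensorFlow Lite native libraries for Android and places them in the project.

Place the downloaded file in the project root.

Then run it from the project root to install the libraries: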

install.bat

Configuring assets (pubspec.yaml)

flutter:
  assets:
    - assets/model.tflite
    - assets/labels.txt

Implementation steps

Creating the main screen

import 'dart:io';
import 'package:flutter_tflite_cassification/classifier.dart';
import 'package:image/image.dart' as img;
import 'package:flutter/material.dart';
import 'package:image_picker/image_picker.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

void main() {
  runApp(const MyApp());
}

class MyApp extends StatelessWidget {
  const MyApp({Key? key}) : super(key: key);

  @override
  Widget build(BuildContext context) {
    return MaterialApp(
      title: 'Flutter Demo',
      theme: ThemeData(
        primarySwatch: Colors.blue,
        useMaterial3: true,
      ),
      home: const MyHomePage(title: 'Flutter Classification'),
    );
  }
}

class MyHomePage extends StatefulWidget {
  const MyHomePage({Key? key, required this.title}) : super(key: key);
  final String title;

  @override
  State<MyHomePage> createState() => _MyHomePageState();
}

class _MyHomePageState extends State<MyHomePage> {
  final picker = ImagePicker();
  String? _imagePath;
  File? _image;
  Image? _imageWidget;

  late Classifier _classifier;

  Category? category;

  @override
  void initState() {
    super.initState();
    _classifier = Classifier();
  }

  Future<void> _getImage() async {
    final pickedFile = await picker.pickImage(source: ImageSource.gallery);
    if (pickedFile == null) return; // the user cancelled the picker

    setState(() {
      _imagePath = pickedFile.path;
      _image = File(_imagePath!);
      _imageWidget = Image.file(_image!);
    });

    // Run classification outside setState; _predict updates the UI itself
    _predict();
  }

  void _predict() {
    // Decode the picked file into an image the classifier can process
    final img.Image imageInput = img.decodeImage(_image!.readAsBytesSync())!;
    final pred = _classifier.predict(imageInput);
    setState(() {
      category = pred;
    });
  }

  @override
  Widget build(BuildContext context) {
    return Scaffold(
      appBar: AppBar(
        title: Text(widget.title),
      ),
      body: Center(
        child: Column(
          mainAxisAlignment: MainAxisAlignment.center,
          children: <Widget>[
            Center(
              child: _image == null
                  ? const Text("no image")
                  : Container(
                      child: _imageWidget,
                    ),
            ),
            const SizedBox(height: 36,),
            Text(category != null ?category!.label : "",
            style: const TextStyle(fontSize: 20, fontWeight: FontWeight.w600),
            ),
            const SizedBox(height: 8,),
            Text(
            category != null? 'Score: ${category!.score.toStringAsFixed(3)}': '',
            style: const TextStyle(fontSize: 16),
          ),
          ],
        ),
      ),
      floatingActionButton: FloatingActionButton(
        onPressed: _getImage,
        tooltip: 'Pick Image',
        child: const Icon(Icons.add_a_photo),
      ),
    );
  }
}

Creating the classifier class

import 'dart:math';
import 'package:image/image.dart';
import 'package:collection/collection.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';

class Classifier {
  late Interpreter interpreter;
  late InterpreterOptions _interpreterOptions;

  late List<int> _inputShape;
  late List<int> _outputShape;

  late TensorImage _inputImage;
  late TensorBuffer _outputBuffer;

  late TfLiteType _inputType;
  late TfLiteType _outputType;

  late TensorProcessor _probabilityProcessor;

  final String modelName = 'model.tflite';
  final String _labelFileName = 'assets/labels.txt';
  final int _labelLength = 1000;
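  // ImageNet mean/std averaged over the RGB channels and scaled to 0-255:
  // (0.485 + 0.456 + 0.406) / 3 * 255 = 114.495, (0.229 + 0.224 + 0.225) / 3 * 255 = 57.63
  final NormalizeOp preProcessNormalizeOp = NormalizeOp(114.495, 57.63);
  // Identity op (mean 0, std 1): the raw output values pass through unchanged
  final NormalizeOp postProcessNormalizeOp = NormalizeOp(0, 1);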
  late List<String> labels;

  Classifier({int? numThreads}) {
    _interpreterOptions = InterpreterOptions();

    if (numThreads != null) {
      _interpreterOptions.threads = numThreads;
    }

    loadModel();
    loadLabels();
  }

  // Load the model from assets
  Future<void> loadModel() async {
    try {
      interpreter =
          await Interpreter.fromAsset(modelName, options: _interpreterOptions);
      print('Interpreter Created Successfully');

      _inputShape = interpreter.getInputTensor(0).shape;
      _outputShape = interpreter.getOutputTensor(0).shape;
      _inputType = interpreter.getInputTensor(0).type;
      _outputType = interpreter.getOutputTensor(0).type;

      _outputBuffer = TensorBuffer.createFixedSize(_outputShape, _outputType);
      _probabilityProcessor =
          TensorProcessorBuilder().add(postProcessNormalizeOp).build();
    } catch (e) {
      print('Unable to create interpreter, Caught Exception: ${e.toString()}');
    }
  }

  // Load the label file
  Future<void> loadLabels() async {
    labels = await FileUtil.loadLabels(_labelFileName);
    if (labels.length == _labelLength) {
      print('Labels loaded successfully');
    } else {
      print('Unable to load labels');
    }
  }

  // Image pre-processing: crop to a centered square, resize to the model input size, then normalize
  TensorImage _preProcess() {
    int cropSize = min(_inputImage.height, _inputImage.width);
    return ImageProcessorBuilder()
        .add(ResizeWithCropOrPadOp(cropSize, cropSize))
        .add(ResizeOp(_inputShape[2], _inputShape[3], ResizeMethod.BILINEAR))
        .add(preProcessNormalizeOp)
        .build()
        .process(_inputImage);
  }

  // Inference
  Category predict(Image image) {
    print("---start predict---");
    // Pre-process the image
    _inputImage = TensorImage(_inputType);
    _inputImage.loadImage(image);
    _inputImage = _preProcess();

    // The model input shape is [1, channel, height, width] (NCHW)
    int ch = _inputShape[1];
    int h = _inputShape[2];
    int w = _inputShape[3];

    // Destination tensor in NCHW order, initialized with zeros
    List inputImage = List.filled(ch * h * w, 0.0).reshape([1, ch, h, w]);
    // The pre-processed TensorImage buffer is stored per pixel (NHWC)
    List inputImageList = _inputImage
        .getTensorBuffer()
        .getBuffer()
        .asFloat32List()
        .reshape([1, h, w, ch]);

    // (1,224,224,3) => (1,3,224,224)
    for (int c = 0; c < ch; c++) {
      for (int y = 0; y < h; y++) {
        for (int x = 0; x < w; x++) {
          inputImage[0][c][y][x] = inputImageList[0][y][x][c];
        }
      }
    }
    
    // Run inference
    interpreter.run(inputImage, _outputBuffer.getBuffer());
    Map<String, double> labeledProb = TensorLabel.fromList(
            labels, _probabilityProcessor.process(_outputBuffer))
        .getMapWithFloatValue();

    labeledProb = softmax(labeledProb);
    final pred = getTopProbability(labeledProb);
    return Category(pred.key, pred.value);
  }

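  // Converts raw logits into probabilities that sum to 1.
  // Every value, including the first, goes through exp(); the max logit is
  // subtracted first so exp() cannot overflow. Returns a new map.
  Map<String, double> softmax(Map<String, double> logits) {
    final maxLogit = logits.values.reduce((a, b) => a > b ? a : b);
    final exps =
        logits.map((key, value) => MapEntry(key, exp(value - maxLogit)));
    final sum = exps.values.reduce((a, b) => a + b);
    return exps.map((key, value) => MapEntry(key, value / sum));
  }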

  MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
    var pq = PriorityQueue<MapEntry<String, double>>(compare);
    pq.addAll(labeledProb.entries);
    return pq.first;
  }

  int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
    if (e1.value > e2.value) {
      return -1;
    } else if (e1.value == e2.value) {
      return 0;
    } else {
      return 1;
    }
  }
}
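The constants passed to `NormalizeOp` in `Classifier` appear to be torchvision's per-channel ImageNet mean and standard deviation, averaged over the three RGB channels and scaled from the 0-1 range to 0-255 pixel values. The sketch below reproduces that arithmetic and the `(value - mean) / std` formula. Note that torchvision normalizes each channel separately, so a single averaged mean/std only approximates the PyTorch preprocessing.

```python
# torchvision's standard ImageNet normalization constants (per RGB channel, 0-1 scale)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def averaged_constants(mean, std, scale=255.0):
    """Collapse per-channel mean/std into one value each, scaled to 0-255 pixel values."""
    return sum(mean) / len(mean) * scale, sum(std) / len(std) * scale


def normalize(pixel, mean, std):
    """Normalization formula applied to each pixel value: (value - mean) / std."""
    return (pixel - mean) / std


m, s = averaged_constants(IMAGENET_MEAN, IMAGENET_STD)
print(round(m, 3), round(s, 3))  # 114.495 57.63
print(normalize(255.0, m, s))    # a fully saturated channel value after normalization
```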

Result

References

Trying real-time object detection with Flutter and TensorFlow Lite (TFLite) | うぇるち | note

tflite_flutter_helper/example/image_classification at master · am15h/tflite_flutter_helper