构建移动版图像分类模型
2024 年 07 月 16 日
看到 Google 的一篇例子教程,折腾了一天。然后发现跟实际需求有偏差,没必要深入分析。为了让自己浪费的这一天显得不是那么浪费,还是把学到的东西记录一下。
Google 的例子教程讲述了如何构建一个移动图像分类(Image Classification)模型。具体一点,根据例子构建得到的图像分类模型,可以运行在手机等设备的 tensorflow lite 上,可以对输入的图片进行分类。说实话,听起来是个很有用的功能,不过我没想到具体的场景。另外,这个例子只能对图片进行分类,可以理解为给图片加单个标签。如果是更复杂的情况(比如希望加多个标签),这个例子应该是不适用的。
原始文档链接在这里。该文档中还提供了一个 Jupyter Notebook 文档,可以在 colab 中运行。不过 colab 的资源难以保证,所以我就在服务器上运行了。注意,要保证 tensorflow 安装正常可能有点麻烦,所以建议装好显卡驱动后使用 docker,这里就不讨论 tensorflow docker 如何运行了。
另注:Google 似乎有一篇更全面的教程讲述这个例子。
原始文档其实该说的都说了,我主要把我整理的核心代码和注释放出来。
# 导入必须的库,核心是 mediapipe_model_maker 中的 image_classifier
import os
import tensorflow as tf
from mediapipe_model_maker import image_classifier
import matplotlib.pyplot as plt
# 原始数据集在 https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
# 解压后可以得到一个目录,里面包含 5 个子目录,每个子目录的名字是一种花的名字,里面的图片是对应的花的图片
image_path = "/tf/notebooks/flower_photos"
# 构建一个标签数组,即目录中所有子目录的名字
print(f"image_path={image_path}")
labels = []
for i in os.listdir(image_path):
if os.path.isdir(os.path.join(image_path, i)):
labels.append(i)
print(f"labels={labels}")
# 下面这段代码可以在 jupyter notebook 环境下显示每个目录的前 5 张图片,非必须运行
%matplotlib inline
NUM_EXAMPLES = 5
for label in labels:
label_dir = os.path.join(image_path, label)
example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]
fig, axs = plt.subplots(1, NUM_EXAMPLES, figsize=(10,2))
for i in range(NUM_EXAMPLES):
axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))
axs[i].get_xaxis().set_visible(False)
axs[i].get_yaxis().set_visible(False)
fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')
plt.show()
# 从目录中的图片构建数据集
data = image_classifier.Dataset.from_folder(image_path)
# 将数据集分为两部分,80% 的数据用于训练,20% 的数据用于其他用途
train_data, remaining_data = data.split(0.8)
# 将用于其他用途的数据,也就是那 20%,分为两半
# 一部分做测试数据集,即训练完成后用来检测模型运行效果的数据集
# 另一部分做验证数据集,即每一轮训练完成后用来检验效果的数据集
test_data, validation_data = remaining_data.split(0.5)
# 选择训练参数,这里选择了 MOBILENET_V2 模型
# 支持的模型可以参考:https://ai.google.dev/edge/api/mediapipe/python/mediapipe_model_maker/image_classifier/SupportedModels
# 另外还有其他几种,MOBILENET_V2 似乎是最轻量的
spec = image_classifier.SupportedModels.MOBILENET_V2
hparams = image_classifier.HParams(export_dir="exported_model")
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)
# 使用训练数据集和验证数据集,以及配置的参数开始训练
model = image_classifier.ImageClassifier.create(
train_data = train_data,
validation_data = validation_data,
options=options,
)
# 训练完成后,使用测试数据集进行评估
loss, acc = model.evaluate(test_data)
print(f'Test loss:{loss}, Test accuracy:{acc}')
# 将训练好的模型导出,得到 tflite 文件
model.export_model()
print("re-trained model=exported_model/model.tflite")
# 下面是使用不同参数训练得到第二个模型,并对它进行评估的部分。主要用于数据对比,不是必须运行
# 这一段使用了更高的 dropout_rate,需要跑更多次
hparams=image_classifier.HParams(epochs=15, export_dir="exported_model_2")
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)
options.model_options = image_classifier.ModelOptions(dropout_rate = 0.07)
model_2 = image_classifier.ImageClassifier.create(
train_data = train_data,
validation_data = validation_data,
options=options,
)
loss, accuracy = model_2.evaluate(test_data)
print(f'Model2: Test loss:{loss}, Test accuracy:{acc}')
根据上面的代码,如果有另外的数据集,就可以拿来构建图片分类功能。需要注意的是,数据集不能太小,以例子中的数据集为参考,每种花的图片大约有数百张,整体有三千张以上图片。数据集如果太小的话,无法得到可接受的效果。
原始的 Jupyter Notebook 文档的最后,整理了使用不同模型参数训练后得到的模型大小和测试准确率。我自己使用 MOBILENET_V2 训练后测试得到的准确率大约在 87%。感觉应该是基本可用了。由于跟我的实际需求有偏差,因此关于训练得到的模型的具体使用方法,我没有进一步研究。我的理解大概是,使用一个支持 tflite 格式的推理引擎,引入生成的模型,然后输入图片,就能得到响应了。具体还得根据实际需求进行试验。