|
|
"""
|
|
|
模块作者:
|
|
|
AI代码结构:刘钰廷、冯雅君
|
|
|
代码优化整理:王昱博、冯昌盛
|
|
|
AI模型训练/纠错:刘钰廷、冯雅君、冯昌盛
|
|
|
代码整合/打包:王昱博
|
|
|
模块用途:
|
|
|
图像分类AI,用于区分车牌的具体类型
|
|
|
"""
|
|
|
|
|
|
import cv2
|
|
|
from PIL import Image
|
|
|
from pathlib import Path
|
|
|
from fastai.vision.all import *
|
|
|
from fastai.metrics import error_rate
|
|
|
from fastai.learner import load_learner
|
|
|
from torchvision.models import resnet34
|
|
|
from fastai.vision.data import ImageBlock
|
|
|
from fastai.vision.core import imagenet_stats
|
|
|
from fastai.data.block import CategoryBlock, DataBlock
|
|
|
from fastai.vision.augment import Resize, aug_transforms
|
|
|
from fastai.vision.learner import cnn_learner, vision_learner
|
|
|
from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
|
|
|
|
|
|
|
|
|
class ClassificationAI:
    """Image-classification helper that distinguishes license-plate types.

    ``TrainAI`` trains a fastai ``vision_learner`` (ResNet-34 backbone) on a
    folder-per-class image dataset and exports it; ``PredictImage`` classifies
    a single OpenCV image with such an exported model.
    """

    # Learners already loaded by PredictImage, keyed by resolved model path.
    # Each value is (file mtime in ns, learner) so that re-exporting a model
    # to the same path is picked up instead of serving a stale learner.
    _learner_cache: dict = {}

    @staticmethod
    def ConvertImage(cv_img: cv2.Mat) -> Image.Image:
        """Convert an OpenCV image into a 460x460 RGB PIL image.

        Args:
            cv_img: Image array as produced by OpenCV: either single-channel
                grayscale or a colour image in BGR (or BGRA) channel order.

        Returns:
            An RGB ``PIL.Image.Image`` resized to exactly 460x460 (the aspect
            ratio is not preserved).
        """
        # COLOR_BGR2RGB rejects single-channel input, so grayscale frames need
        # their own conversion code.
        is_gray = cv_img.ndim == 2 or (cv_img.ndim == 3 and cv_img.shape[2] == 1)
        code = cv2.COLOR_GRAY2RGB if is_gray else cv2.COLOR_BGR2RGB
        return Image.fromarray(cv2.cvtColor(cv_img, code)).resize((460, 460))

    @classmethod
    def TrainAI(
        cls,
        data_set_path: str,
        export_path: str,
        epochs: int = 5,
        freeze_epochs: int = 3,
        batch_size: int = 32,
        num_workers: int = 4,
        seed: "int | None" = None,
    ) -> None:
        """Train a ResNet-34 classifier and export it as ``model.pkl``.

        Labels are taken from each image's parent folder name
        (``parent_label``), and ``RandomSplitter`` holds out a validation set.

        Args:
            data_set_path: Root folder of the dataset; images are collected
                recursively with ``get_image_files``.
            export_path: Directory that receives ``model.pkl``. It is created
                if it does not exist.
            epochs: Number of unfrozen training epochs passed to ``fine_tune``.
            freeze_epochs: Number of epochs trained with the backbone frozen
                before unfreezing.
            batch_size: Mini-batch size for the data loaders.
            num_workers: Worker processes used by the data loaders.
            seed: Optional seed for the train/validation split; ``None`` keeps
                the split random on every run.
        """
        dls = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            splitter=RandomSplitter(seed=seed),
            get_y=parent_label,
            item_tfms=Resize(460),
            batch_tfms=[
                *aug_transforms(size=224, min_scale=0.75),
                Normalize.from_stats(*imagenet_stats),
            ],
        ).dataloaders(data_set_path, num_workers=num_workers, bs=batch_size)

        learner = vision_learner(dls, resnet34, metrics=error_rate)
        learner.fine_tune(epochs, freeze_epochs=freeze_epochs)

        export_dir = Path(export_path)
        # Saving the model does not create missing directories.
        export_dir.mkdir(parents=True, exist_ok=True)
        # Learner.export joins the file name onto learner.path; an absolute
        # target makes the destination independent of that base path.
        learner.export((export_dir / 'model.pkl').absolute())

    @classmethod
    def _LoadLearner(cls, model_path: str):
        """Return the learner exported at *model_path*, loading each file version once.

        Raises:
            FileNotFoundError: If *model_path* does not exist.
        """
        path = Path(model_path).resolve()
        mtime = path.stat().st_mtime_ns
        cached = cls._learner_cache.get(path)
        if cached is not None and cached[0] == mtime:
            return cached[1]
        # SECURITY: load_learner unpickles the file, which can execute
        # arbitrary code -- only load model files from trusted sources.
        learner = load_learner(path)
        cls._learner_cache[path] = (mtime, learner)
        return learner

    @classmethod
    def PredictImage(cls, image: cv2.Mat, model_path: str) -> tuple:
        """Classify one OpenCV image with an exported model.

        The learner is loaded from *model_path* on first use and reused on
        later calls until the file on disk changes.

        Args:
            image: OpenCV image (grayscale, BGR or BGRA) to classify.
            model_path: Path to a model file written by ``TrainAI``.

        Returns:
            ``(pred_class, confidence)``: the predicted label as returned by
            ``Learner.predict`` and the model output for that label as a float.
        """
        learner = cls._LoadLearner(model_path)
        img = cls.ConvertImage(image)
        pred_class, pred_idx, outputs = learner.predict(img)

        # Use the output entry of the predicted class; a 0-d output tensor is
        # taken as the confidence directly.
        if outputs.dim() == 0:
            confidence = float(outputs)
        else:
            confidence = float(outputs[pred_idx])

        return pred_class, confidence
|