You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

82 lines
3.0 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
模块作者:
AI代码结构:刘钰廷、冯雅君
代码优化整理:王昱博、冯昌盛
AI模型训练/纠错:刘钰廷、冯雅君、冯昌盛
代码整合/打包:王昱博
模块用途:
图像分类AI用于区分车牌的具体类型
"""
import cv2
from PIL import Image
from pathlib import Path
from fastai.vision.all import *
from fastai.metrics import error_rate
from fastai.learner import load_learner
from torchvision.models import resnet34
from fastai.vision.data import ImageBlock
from fastai.vision.core import imagenet_stats
from fastai.data.block import CategoryBlock, DataBlock
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.learner import cnn_learner, vision_learner
from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
class ClassificationAI:
    """Image classifier that tells the different licence-plate categories apart.

    Every method is a static or class method, so no instance is required.
    """

    # Label predicted by the model -> Chinese display name handed back to callers.
    # The keys are presumably the dataset folder names, since ``TrainAI`` labels
    # items with ``parent_label`` -- verify against the training data layout.
    _CLASS_NAMES = {
        'ForeignV': '外籍车辆',
        'In-fieldV': '场内车辆',
        'large-scaleNewenergyV': '大型新能源车辆',
        'MediumLarge-sizedV': '中/大型车辆',
        'MilitaryPoliceEmergencyV': '军/警/应急车辆',
        'SmallCar': '小型轿车',
        'SmallNewEnergyV': '小型新能源轿车',
    }
    # Display name returned for any label that is not in ``_CLASS_NAMES``.
    _UNKNOWN_NAME = '未知'

    # Loaded learners, keyed by resolved model path -> (mtime_ns, learner).
    # Deliberately shared at class level so repeated predictions reuse the model.
    _model_cache = {}

    @staticmethod
    def ConvertImage(cv_img: "cv2.Mat") -> "Image.Image":
        """Convert an OpenCV BGR image into a 460x460 RGB PIL image.

        The annotations are quoted so they are not evaluated when the class is
        defined; the module therefore still imports when ``cv2`` has no ``Mat``
        attribute.

        Args:
            cv_img: Image in OpenCV's BGR channel order.

        Returns:
            The image converted to RGB and resized to 460x460 pixels.
        """
        # NOTE(review): resizing to a fixed square ignores the aspect ratio,
        # while TrainAI uses ``Resize(460)`` (crop by default according to the
        # fastai docs) -- confirm this train/inference difference is intended.
        rgb = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb).resize((460, 460))

    @staticmethod
    def ConvertClassifyResult(cla: str) -> str:
        """Translate a model label into its Chinese display name.

        Args:
            cla: Label produced by the classifier, e.g. ``'SmallCar'``.

        Returns:
            The matching display name, or ``'未知'`` for an unrecognised label.
        """
        return ClassificationAI._CLASS_NAMES.get(cla, ClassificationAI._UNKNOWN_NAME)

    @classmethod
    def TrainAI(cls, data_set_path: str, export_path: str) -> None:
        """Fine-tune a ResNet-34 classifier and export it as ``model.pkl``.

        Args:
            data_set_path: Root directory of the training images; each image is
                labelled with the name of its parent directory.
            export_path: Directory that receives ``model.pkl``. It is created
                if it does not exist yet.
        """
        batch_size = 32
        data_block = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            splitter=RandomSplitter(),
            get_y=parent_label,
            item_tfms=Resize(460),
            # Augment at 224 px, then normalise with the ImageNet statistics.
            batch_tfms=[
                *aug_transforms(size=224, min_scale=0.75),
                Normalize.from_stats(*imagenet_stats),
            ],
        )
        dls = data_block.dataloaders(data_set_path, num_workers=0, bs=batch_size)

        learner = vision_learner(dls, resnet34, metrics=error_rate)
        # 5 training epochs, preceded by 3 epochs with the pretrained body frozen.
        learner.fine_tune(5, freeze_epochs=3)

        # ``Learner.export`` writes to ``learner.path / fname`` and ``learner.path``
        # defaults to the dataset directory, so a relative ``export_path`` would
        # end up inside ``data_set_path``. Resolving it to an absolute path keeps
        # the export where the caller asked for it.
        export_dir = Path(export_path).resolve()
        export_dir.mkdir(parents=True, exist_ok=True)
        learner.export(export_dir / 'model.pkl')

    @classmethod
    def _load_model(cls, model_path):
        """Return the learner stored at ``model_path``, loading it at most once.

        The cached learner is replaced when the file's modification time
        changes, so a model retrained to the same path is picked up.
        """
        model_file = Path(model_path).resolve()
        stamp = model_file.stat().st_mtime_ns
        entry = cls._model_cache.get(model_file)
        if entry is None or entry[0] != stamp:
            # NOTE(security): ``load_learner`` unpickles the file, which can run
            # arbitrary code -- only load model files from a trusted source.
            entry = (stamp, load_learner(model_path))
            cls._model_cache[model_file] = entry
        return entry[1]

    @classmethod
    def PredictImage(cls, image: "cv2.Mat", model_path: str) -> tuple:
        """Classify one image and report the confidence of the prediction.

        Args:
            image: Image in OpenCV's BGR channel order.
            model_path: Path of the exported learner file (``.pkl``).

        Returns:
            A ``(display_name, confidence)`` tuple: the Chinese display name of
            the predicted class and the model output for that class as a float.
        """
        # Load the model (reused across calls instead of unpickling every time).
        model = cls._load_model(model_path)
        # Convert the OpenCV image into the PIL image passed to the learner.
        img = cls.ConvertImage(image)
        # Run the prediction.
        pred_class, pred_idx, outputs = model.predict(img)
        # Pick the confidence; a 0-dim output tensor holds a single value.
        if outputs.dim() == 0:
            confidence = float(outputs)
        else:
            confidence = float(outputs[pred_idx])
        # ``str(...)`` hands the label table a plain, hashable string.
        return cls.ConvertClassifyResult(str(pred_class)), confidence