You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

82 lines
3.0 KiB
Python

2 years ago
"""
模块作者
AI代码结构刘钰廷冯雅君
代码优化整理王昱博冯昌盛
AI模型训练/纠错刘钰廷冯雅君冯昌盛
代码整合/打包王昱博
模块用途
图像分类AI用于区分车牌的具体类型
"""
import cv2
from PIL import Image
from pathlib import Path
from fastai.vision.all import *
from fastai.metrics import error_rate
from fastai.learner import load_learner
from torchvision.models import resnet34
from fastai.vision.data import ImageBlock
from fastai.vision.core import imagenet_stats
from fastai.data.block import CategoryBlock, DataBlock
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.learner import cnn_learner, vision_learner
from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
class ClassificationAI:
    """Image classifier that identifies the specific type of a licence plate.

    ``TrainAI`` fine-tunes a ResNet-34 with fastai and exports it;
    ``PredictImage`` classifies an OpenCV (BGR) image with an exported model
    and returns a Chinese display name together with the model's confidence.
    """

    # Model class labels -> Chinese display names. The keys are expected to
    # match the dataset's class-folder names, since TrainAI labels images
    # with ``parent_label``.
    PLATE_TYPE_NAMES = {
        'ForeignV': '外籍车辆',
        'In-fieldV': '场内车辆',
        'large-scaleNewenergyV': '大型新能源车辆',
        'MediumLarge-sizedV': '中/大型车辆',
        'MilitaryPoliceEmergencyV': '军/警/应急车辆',
        'SmallCar': '小型轿车',
        'SmallNewEnergyV': '小型新能源轿车',
    }
    # Display name returned for any label not listed above ("unknown").
    UNKNOWN_PLATE_TYPE = '未知'

    # Loaded learners: resolved model path -> (file mtime in ns, learner).
    # Avoids unpickling the whole network from disk on every prediction.
    _model_cache: dict = {}

    @staticmethod
    def ConvertImage(cv_img: cv2.Mat) -> Image.Image:
        """Convert an OpenCV BGR image into a 460x460 RGB PIL image.

        NOTE(review): ``resize`` squashes the image to a square regardless of
        its aspect ratio, while training uses ``Resize(460)``, whose fastai
        default is (presumably) a crop. Confirm this train/inference
        difference is intended.
        """
        rgb = cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB)
        return Image.fromarray(rgb).resize((460, 460))

    @staticmethod
    def ConvertClassifyResult(cla: str) -> str:
        """Translate a model class label into its Chinese display name.

        Labels that are not recognised map to ``'未知'`` ("unknown").
        """
        return ClassificationAI.PLATE_TYPE_NAMES.get(
            cla, ClassificationAI.UNKNOWN_PLATE_TYPE)

    @classmethod
    def TrainAI(cls, data_set_path: str, export_path: str, *,
                epochs: int = 5, freeze_epochs: int = 3,
                batch_size: int = 32, seed: "int | None" = None) -> None:
        """Fine-tune a ResNet-34 classifier and export it as ``model.pkl``.

        Args:
            data_set_path: Root folder of the training images. Each image is
                labelled with the name of its parent folder (``parent_label``).
            export_path: Directory that receives ``model.pkl``. It is created
                if it does not exist.
            epochs: Number of main training epochs (default 5).
            freeze_epochs: Number of frozen epochs run first (default 3).
            batch_size: Mini-batch size (default 32).
            seed: Optional seed for the random train/validation split. Pass an
                int for a reproducible split; ``None`` keeps it random.
        """
        dls = DataBlock(
            blocks=(ImageBlock, CategoryBlock),
            get_items=get_image_files,
            splitter=RandomSplitter(seed=seed),
            get_y=parent_label,
            # Presizing: resize each item to 460 first, then augment and
            # downscale to 224 per batch. Normalise with ImageNet statistics
            # to match the pretrained ResNet-34 weights.
            item_tfms=Resize(460),
            batch_tfms=[*aug_transforms(size=224, min_scale=0.75),
                        Normalize.from_stats(*imagenet_stats)],
        # NOTE(review): num_workers=0 loads data in the main process —
        # presumably to avoid DataLoader multiprocessing issues (e.g. on
        # Windows); confirm before changing.
        ).dataloaders(data_set_path, num_workers=0, bs=batch_size)

        model = vision_learner(dls, resnet34, metrics=error_rate)
        model.fine_tune(epochs, freeze_epochs=freeze_epochs)

        # Learner.export joins the file name onto the learner's own path, so
        # hand it an absolute path. Create the directory first so saving does
        # not fail on a missing folder.
        export_dir = Path(export_path).resolve()
        export_dir.mkdir(parents=True, exist_ok=True)
        model.export(export_dir / 'model.pkl')

    @classmethod
    def _load_model(cls, model_path):
        """Return the learner stored at *model_path*, reloading only if the file changed."""
        path = Path(model_path).resolve()
        # Key on the file's modification time so that a re-exported model
        # (e.g. after TrainAI) replaces the cached one instead of going stale.
        # A missing file raises FileNotFoundError here, as load_learner would.
        mtime = path.stat().st_mtime_ns
        cached = cls._model_cache.get(path)
        if cached is not None and cached[0] == mtime:
            return cached[1]
        # SECURITY: load_learner unpickles the file, which can execute
        # arbitrary code — only load model files from trusted sources.
        model = load_learner(path)
        cls._model_cache[path] = (mtime, model)
        return model

    @classmethod
    def PredictImage(cls, image: cv2.Mat, model_path: str) -> tuple:
        """Classify a plate image with the exported model at *model_path*.

        Args:
            image: OpenCV image in BGR channel order.
            model_path: Path to an exported learner, typically the
                ``model.pkl`` written by ``TrainAI``.

        Returns:
            ``(display_name, confidence)``: the Chinese display name of the
            predicted class and the model output value for that class.
        """
        model = cls._load_model(model_path)
        img = cls.ConvertImage(image)
        pred_class, pred_idx, outputs = model.predict(img)
        # `outputs` is presumably the per-class probability vector returned by
        # fastai's predict; the 0-d branch handles a scalar output tensor.
        if outputs.dim() == 0:
            confidence = float(outputs)
        else:
            confidence = float(outputs[pred_idx])
        return cls.ConvertClassifyResult(pred_class), confidence