You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

57 lines
2.2 KiB
Python

2 years ago
"""
主程序作者王昱博
2 years ago
2 years ago
车牌识别系统
使用OCR技术对车牌号码进行识别
使用图像分类AI对车牌种类进行区分
"""
import cv2
from ocr import OCR
from cut_image import ImageCutter
from classification_ai import ClassificationAI
classify_models = ['.\\classify_model\\0.0625.pkl', '.\\classify_model\\0.0625-2.pkl', '.\\classify_model\\0.125.pkl']
def train(train_set_path: str, export_path: str) -> None:
ClassificationAI.TrainAI(train_set_path, export_path)
def main(classify_model_index: int, image_path: str) -> None:
global classify_models
origin_image, gray_image = ImageCutter.ImagePreProcess(image_path)
lpr_text, lpr_conf, cut_image = OCR.RecognizeLicensePlate2(origin_image)
2 years ago
if cut_image is None:
cut_image = ImageCutter.CutPlateRect(origin_image, gray_image)
ocr_text, ocr_type = OCR.RecognizeLicensePlate(cut_image, lpr_text)
if lpr_text is None:
lpr_text = ocr_text
lpr_conf = None
2 years ago
ai_type, ai_conf = ClassificationAI.PredictImage(cut_image, classify_models[classify_model_index])
print(f'识别完成,以下为识别结果:\n车牌号:{lpr_text} [置信度:{lpr_conf}]\n车牌类型:\n\t{ocr_type}(OCR推测)\n\t{ai_type}(AI分类识别)\n\tAI识别置信度{ai_conf}')
if __name__ == '__main__':
2 years ago
result = input('请选择运行模式(训练(t)/识别(r)): ')
if result == 't' or result == 'T':
2 years ago
data_path = input('输入训练集路径: ')
export_path = input('输入模型保存路径: ')
try:
train(data_path, export_path)
except Exception as e:
print(f'训练过程中发生错误: {e}')
else:
print('模型已成功训练')
finally:
print('训练结束')
2 years ago
elif result == 'r' or result == 'R':
2 years ago
model_index = input('选择使用的识别模型(1/2/3): ')
image_path = input('输入图片路径: ')
if (not model_index.isdigit()) or (int(model_index) < 1) or (int(model_index) > 3):
print('输入有误')
else:
main(int(model_index), image_path)
else:
print('输入有误')