You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

57 lines
2.2 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# G:\Users\15819\Desktop\Images
# 导入所需库
from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
from fastai.learner import load_learner
from fastai.metrics import error_rate
from pathlib import Path
from fastai.data.block import CategoryBlock, DataBlock
from fastai.vision.all import *
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner, vision_learner
from torchvision.models import resnet34
from PIL import Image
# Locations of the training images and of the directory the model is saved to.
data_path = Path('G:/Users/15819/Desktop/Images')
export_path = Path('G:/Users/15819/Desktop')

# Data block types: image inputs, single-label category targets.
blocks = (ImageBlock, CategoryBlock)

# Build the data loaders. Every image file found under data_path is labelled
# with the name of its parent folder, and RandomSplitter() sets aside a random
# validation subset.
batch_size = 32
dls = DataBlock(
    blocks=blocks,
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=parent_label,
    # Resize each item to 460 first; the batch augmentations then produce the
    # final 224 crops and normalize them with the ImageNet statistics.
    item_tfms=Resize(460),
    batch_tfms=[
        *aug_transforms(size=224, min_scale=0.75),
        Normalize.from_stats(*imagenet_stats),
    ],
).dataloaders(data_path, num_workers=4, bs=batch_size)
# NOTE(review): the paths suggest this runs on Windows. There, DataLoader
# worker processes (num_workers > 0) re-import the main script, so top-level
# training code normally has to sit behind `if __name__ == '__main__':` --
# confirm on the target machine.

# Define the model: a ResNet-34 learner that reports the error rate.
model = vision_learner(dls, resnet34, metrics=error_rate)

# Train the model. Fine-tuning reuses the features learned by the pretrained
# network to improve performance on the new task: 3 epochs with the body
# frozen, followed by 5 epochs with all layers unfrozen.
model.fine_tune(5, freeze_epochs=3)

# Save the trained learner. The target is derived from export_path so the
# output directory is defined in exactly one place.
model.export(export_path / 'model01.pkl')
# Prediction helpers.
def open_image(image_path):
    """Open the image file at ``image_path`` and return it as an RGB image.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A fully loaded ``PIL.Image.Image`` in ``'RGB'`` mode.

    Raises:
        Whatever ``PIL.Image.open`` raises for a missing or unreadable file
        (e.g. ``FileNotFoundError``).
    """
    # Image.open is lazy and keeps the file handle open. convert() forces the
    # pixel data to be read and returns an independent image, so the handle
    # can be released as soon as the `with` block exits. Converting to RGB
    # also makes grayscale/CMYK/RGBA files match the 3-channel input the
    # classifier is fed.
    with Image.open(image_path) as img:
        return img.convert('RGB')
def predict_image(image_path, model_path='G:/Users/15819/Desktop/model01.pkl'):
    """Classify a single image with the exported learner.

    Args:
        image_path: Path of the image file to classify.
        model_path: Path of the exported learner file. Defaults to the file
            this script has always loaded.

    Returns:
        A ``(pred_class, confidence)`` tuple: the predicted label as returned
        by ``Learner.predict`` and the probability of that label as a float.
    """
    # Load the exported learner.
    # NOTE: load_learner unpickles the file, so only load model files that
    # come from a trusted source.
    learner = load_learner(model_path)

    # Read the image file found at image_path.
    img = open_image(image_path)

    # Run the prediction: decoded label, its index, and the output values for
    # every class.
    pred_class, pred_idx, outputs = learner.predict(img)

    # Confidence is the output value of the predicted class. Indexing with
    # pred_idx yields a 0-d tensor, which the builtin max() cannot iterate
    # over, so convert it to a float directly.
    confidence = float(outputs[pred_idx])
    return pred_class, confidence
# Try the prediction function on one sample image and print the predicted
# class together with its confidence.
image_path = 'G:/Users/15819/Desktop/Images/SmallCar/川A8K059.jpg'
pred_class, confidence = predict_image(image_path=image_path)
print(f"图片类别: {pred_class}, 置信度: {confidence}")