|
|
@ -1,18 +1,57 @@
|
|
|
|
# 这是一个示例 Python 脚本。
|
|
|
|
# G:\Users\15819\Desktop\Images
|
|
|
|
|
|
|
|
|
|
|
|
# 按 Shift+F10 执行或将其替换为您的代码。
|
|
|
|
# 导入所需库
|
|
|
|
# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。
|
|
|
|
from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
|
|
|
|
|
|
|
|
from fastai.learner import load_learner
|
|
|
|
|
|
|
|
from fastai.metrics import error_rate
|
|
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
from fastai.data.block import CategoryBlock, DataBlock
|
|
|
|
|
|
|
|
from fastai.vision.all import *
|
|
|
|
|
|
|
|
from fastai.vision.augment import Resize, aug_transforms
|
|
|
|
|
|
|
|
from fastai.vision.core import imagenet_stats
|
|
|
|
|
|
|
|
from fastai.vision.data import ImageBlock
|
|
|
|
|
|
|
|
from fastai.vision.learner import cnn_learner, vision_learner
|
|
|
|
|
|
|
|
from torchvision.models import resnet34
|
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 设置数据路径和模型保存路径
|
|
|
|
|
|
|
|
data_path = Path('G:/Users/15819/Desktop/Images')
|
|
|
|
|
|
|
|
export_path = Path('G:/Users//15819/Desktop')
|
|
|
|
|
|
|
|
# 定义数据块
|
|
|
|
|
|
|
|
blocks = (ImageBlock, CategoryBlock)
|
|
|
|
|
|
|
|
# 创建数据加载器
|
|
|
|
|
|
|
|
batch_size=32
|
|
|
|
|
|
|
|
dls = DataBlock(blocks=blocks,
|
|
|
|
|
|
|
|
get_items=get_image_files,
|
|
|
|
|
|
|
|
splitter=RandomSplitter(),
|
|
|
|
|
|
|
|
get_y=parent_label,
|
|
|
|
|
|
|
|
item_tfms=Resize(460),
|
|
|
|
|
|
|
|
batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)]
|
|
|
|
|
|
|
|
).dataloaders(data_path, num_workers=4, bs=batch_size)
|
|
|
|
|
|
|
|
# 定义模型结构
|
|
|
|
|
|
|
|
model = vision_learner(dls, resnet34, metrics=error_rate)
|
|
|
|
|
|
|
|
# 训练模型
|
|
|
|
|
|
|
|
model.fine_tune(5, freeze_epochs=3)#表示对模型进行微调(fine-tuning),微调的目的是利用预训练模型学到的特征来提高模型在新任务上的性能。
|
|
|
|
|
|
|
|
# 保存模型
|
|
|
|
|
|
|
|
model.export('G:/Users/15819/Desktop/model01.pkl')
|
|
|
|
|
|
|
|
# 定义预测函数
|
|
|
|
|
|
|
|
#def open_image(image_path):
|
|
|
|
|
|
|
|
#pass
|
|
|
|
|
|
|
|
def open_image(image_path):
|
|
|
|
|
|
|
|
img = Image.open(image_path)
|
|
|
|
|
|
|
|
return img
|
|
|
|
|
|
|
|
|
|
|
|
def print_hi(name):
|
|
|
|
def predict_image(image_path):
|
|
|
|
# 在下面的代码行中使用断点来调试脚本。
|
|
|
|
# 加载模型
|
|
|
|
print(f'Hi, {name}') # 按 Ctrl+F8 切换断点。
|
|
|
|
model = load_learner('G:/Users/15819/Desktop/model01.pkl')
|
|
|
|
|
|
|
|
# 读取图片并转换为Tensor
|
|
|
|
|
|
|
|
img = open_image(image_path)#读取指定路径(image_path)下的图像文件
|
|
|
|
# 按装订区域中的绿色按钮以运行脚本。
|
|
|
|
# 进行预测
|
|
|
|
if __name__ == '__main__':
|
|
|
|
pred_class, pred_idx, outputs = model.predict(img)
|
|
|
|
print_hi('PyCharm')
|
|
|
|
# 获取置信度
|
|
|
|
|
|
|
|
confidence = max(outputs[pred_idx])
|
|
|
|
# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助
|
|
|
|
return pred_class, confidence
|
|
|
|
|
|
|
|
# 测试预测函数
|
|
|
|
#debug
|
|
|
|
image_path = 'G:/Users/15819/Desktop/Images/SmallCar/川A8K059.jpg'
|
|
|
|
|
|
|
|
pred_class, confidence = predict_image(image_path)
|
|
|
|
|
|
|
|
print(f"图片类别: {pred_class}, 置信度: {confidence}")
|