From a06c2e98a369d81cc75507ac994ede52e3d62c07 Mon Sep 17 00:00:00 2001 From: UnknownObject Date: Sat, 2 Dec 2023 17:05:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=BA=86=E6=89=BE=E4=B8=8D?= =?UTF-8?q?=E5=88=B0=E6=96=87=E4=BB=B6=E7=9A=84=E8=AF=A1=E5=BC=82BUG?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 69 ++++++++++++++++++++++++++++++++++++++++++++------------- ver2.py | 56 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 110 insertions(+), 15 deletions(-) create mode 100644 ver2.py diff --git a/main.py b/main.py index 2da35bf..96461a1 100644 --- a/main.py +++ b/main.py @@ -1,18 +1,57 @@ -# 这是一个示例 Python 脚本。 +# G:\Users\15819\Desktop\Images -# 按 Shift+F10 执行或将其替换为您的代码。 -# 按 双击 Shift 在所有地方搜索类、文件、工具窗口、操作和设置。 +# 导入所需库 +from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize +from fastai.learner import load_learner +from fastai.metrics import error_rate +from pathlib import Path +from fastai.data.block import CategoryBlock, DataBlock +from fastai.vision.all import * +from fastai.vision.augment import Resize, aug_transforms +from fastai.vision.core import imagenet_stats +from fastai.vision.data import ImageBlock +from fastai.vision.learner import cnn_learner, vision_learner +from torchvision.models import resnet34 +from PIL import Image +# 设置数据路径和模型保存路径 +data_path = Path('G:/Users/15819/Desktop/Images') +export_path = Path('G:/Users//15819/Desktop') +# 定义数据块 +blocks = (ImageBlock, CategoryBlock) +# 创建数据加载器 +batch_size=32 +dls = DataBlock(blocks=blocks, + get_items=get_image_files, + splitter=RandomSplitter(), + get_y=parent_label, + item_tfms=Resize(460), + batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)] + ).dataloaders(data_path, num_workers=4, bs=batch_size) +# 定义模型结构 +model = vision_learner(dls, resnet34, metrics=error_rate) +# 训练模型 +model.fine_tune(5, freeze_epochs=3)#表示对模型进行微调(fine-tuning),微调的目的是利用预训练模型学到的特征来提高模型在新任务上的性能。 +# 保存模型 +model.export('G:/Users/15819/Desktop/model01.pkl') +# 定义预测函数 +#def open_image(image_path): + #pass +def open_image(image_path): + img = Image.open(image_path) + return img -def print_hi(name): - # 在下面的代码行中使用断点来调试脚本。 - print(f'Hi, {name}') # 按 Ctrl+F8 切换断点。 - - -# 按装订区域中的绿色按钮以运行脚本。 -if __name__ == '__main__': - print_hi('PyCharm') - -# 访问 https://www.jetbrains.com/help/pycharm/ 获取 PyCharm 帮助 - -#debug \ No newline at end of file +def predict_image(image_path): + # 加载模型 + model = load_learner('G:/Users/15819/Desktop/model01.pkl') + # 读取图片并转换为Tensor + img = open_image(image_path)#读取指定路径(image_path)下的图像文件 + # 进行预测 + pred_class, pred_idx, outputs = model.predict(img) + # 获取置信度 + confidence = max(outputs[pred_idx]) + return pred_class, confidence +# 测试预测函数 +image_path = 'G:/Users/15819/Desktop/Images/SmallCar/川A8K059.jpg' +pred_class, confidence = predict_image(image_path) +print(f"图片类别: {pred_class}, 置信度: {confidence}") \ No newline at end of file diff --git a/ver2.py b/ver2.py new file mode 100644 index 0000000..ef05c55 --- /dev/null +++ b/ver2.py @@ -0,0 +1,56 @@ +from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize +from fastai.learner import load_learner +from fastai.metrics import error_rate +from pathlib import Path +from fastai.data.block import CategoryBlock, DataBlock +from fastai.vision.all import * +from fastai.vision.augment import Resize, aug_transforms +from fastai.vision.core import imagenet_stats +from fastai.vision.data import ImageBlock +from fastai.vision.learner import cnn_learner, vision_learner +from torchvision.models import resnet34 +from PIL import Image + + +def open_image(image_path): + img = Image.open(image_path) + return img + + +def predict_image(image_path): + # 加载模型 + model = load_learner('G:/Users/15819/Desktop/model01.pkl') + # 读取图片并转换为Tensor + img = open_image(image_path) # 读取指定路径(image_path)下的图像文件 + # 进行预测 + pred_class, pred_idx, outputs = model.predict(img) + # 获取置信度 + confidence = max(outputs[pred_idx]) + return pred_class, confidence + + +def main(): + data_path = Path('G:/Users/15819/Desktop/Images') + export_path = Path('G:/Users//15819/Desktop') + blocks = (ImageBlock, CategoryBlock) + batch_size = 32 + dls = DataBlock( + blocks=blocks, + get_items=get_image_files, + splitter=RandomSplitter(), + get_y=parent_label, + item_tfms=Resize(460), + batch_tfms=[*aug_transforms(size=224, min_scale=0.75), Normalize.from_stats(*imagenet_stats)] + ).dataloaders(data_path, num_workers=4, bs=batch_size) + + model = vision_learner(dls, resnet34, metrics=error_rate) + model.fine_tune(5, freeze_epochs=3) + model.export('G:/Users/15819/Desktop/model01.pkl') + + image_path = 'G:/Users/15819/Desktop/Images/SmallCar/川A8K059.jpg' + pred_class, confidence = predict_image(image_path) + print(f"图片类别: {pred_class}, 置信度: {confidence}") + + +if __name__ == '__main__': + main() \ No newline at end of file