You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

56 lines
1.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
from fastai.learner import load_learner
from fastai.metrics import error_rate
from pathlib import Path
from fastai.data.block import CategoryBlock, DataBlock
from fastai.vision.all import *
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner, vision_learner
from torchvision.models import resnet34
from PIL import Image
def open_image(image_path):
    """Open the image file at ``image_path`` with PIL and return it.

    Args:
        image_path: Location of the image file; passed unchanged to
            ``Image.open``.

    Returns:
        The object produced by ``Image.open(image_path)``.
    """
    return Image.open(image_path)
def predict_image(image_path, model_path='G:/Users/15819/Desktop/model01.pkl'):
    """Classify one image with an exported learner and report its confidence.

    Args:
        image_path: Path of the image file to classify.
        model_path: Location of the exported learner file. Defaults to the
            path this function previously had hard-coded, so existing
            callers keep working unchanged.

    Returns:
        A ``(pred_class, confidence)`` tuple. ``pred_class`` is the first
        value returned by the learner's ``predict``; ``confidence`` is the
        entry of the third returned value at the predicted index, converted
        to a Python ``float``.
    """
    # NOTE(review): load_learner unpickles the file, so only load model
    # files that come from a trusted source.
    model = load_learner(model_path)

    # Read the image file found at image_path.
    img = open_image(image_path)

    # Run the prediction.
    pred_class, pred_idx, outputs = model.predict(img)

    # Indexing with the predicted index selects a single element, so it is
    # converted directly; the previous max(...) call tried to iterate over
    # that single element, which fails for a zero-dimensional tensor.
    confidence = float(outputs[pred_idx])
    return pred_class, confidence
def main():
    """Train an image classifier, export it, and classify one sample image.

    Builds dataloaders from the images under a hard-coded folder (labels are
    produced by ``parent_label``), fine-tunes a ``resnet34`` learner, exports
    it to ``model01.pkl``, then runs ``predict_image`` on one hard-coded
    image and prints the predicted class and confidence.
    """
    data_path = Path('G:/Users/15819/Desktop/Images')
    # Export target; predict_image() loads the learner from this same path
    # by default, so the two must stay in sync.
    model_file = 'G:/Users/15819/Desktop/model01.pkl'
    batch_size = 32

    # Items are resized to 460 first, then augmented down to 224 per batch
    # and normalized with the ImageNet statistics.
    dls = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        splitter=RandomSplitter(),
        get_y=parent_label,
        item_tfms=Resize(460),
        batch_tfms=[
            *aug_transforms(size=224, min_scale=0.75),
            Normalize.from_stats(*imagenet_stats),
        ],
    ).dataloaders(data_path, num_workers=4, bs=batch_size)

    model = vision_learner(dls, resnet34, metrics=error_rate)
    model.fine_tune(5, freeze_epochs=3)
    model.export(model_file)

    # Check the exported model against a single known image.
    image_path = 'G:/Users/15819/Desktop/Images/SmallCar/川A8K059.jpg'
    pred_class, confidence = predict_image(image_path)
    print(f"图片类别: {pred_class}, 置信度: {confidence}")
# Run the train/export/predict workflow only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()