You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

56 lines
1.9 KiB
Python

from fastai.data.transforms import get_image_files, parent_label, RandomSplitter, Normalize
from fastai.learner import load_learner
from fastai.metrics import error_rate
from pathlib import Path
from fastai.data.block import CategoryBlock, DataBlock
from fastai.vision.all import *
from fastai.vision.augment import Resize, aug_transforms
from fastai.vision.core import imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner, vision_learner
from torchvision.models import resnet34
from PIL import Image
def open_image(image_path):
    """Open an image file and return it as an in-memory RGB PIL image.

    Args:
        image_path: Path (``str`` or ``pathlib.Path``) of the image file.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode whose pixel data is already
        loaded, so it no longer depends on the underlying file handle.
    """
    # Image.open is lazy and keeps the file open until the pixels are read.
    # Opening it in a context manager guarantees the handle is closed, and
    # convert() force-loads the data into a new, file-independent image.
    # RGB matches the mode fastai uses when it loads the training images
    # itself (PILImage opens files with mode='RGB'), so grayscale/RGBA/palette
    # inputs reach the model in the same form as during training.
    with Image.open(image_path) as img:
        return img.convert("RGB")
def predict_image(image_path, model_path='G:/Users/15819/Desktop/model01.pkl'):
    """Classify a single image with an exported fastai learner.

    Args:
        image_path: Path of the image file to classify.
        model_path: Path of a learner saved with ``Learner.export``. The
            default is the same file this script's training step exports to.

    Returns:
        ``(pred_class, confidence)``: the decoded label returned by
        ``Learner.predict`` and the model's score for that label as a plain
        Python ``float``. With the default classification loss this score is
        a softmax probability.
    """
    # NOTE: load_learner unpickles the file -- only load model files you trust.
    # The learner is reloaded on every call; cache it when predicting many images.
    learn = load_learner(model_path)
    # Read the image from disk (a PIL image, not a tensor; fastai converts it).
    img = open_image(image_path)
    # Learner.predict returns (decoded label, label index, per-class scores).
    pred_class, pred_idx, probs = learn.predict(img)
    # probs[pred_idx] is a 0-d tensor, which max() cannot iterate over;
    # extract the scalar directly instead.
    confidence = float(probs[pred_idx])
    return pred_class, confidence
def main():
    """Train a ResNet-34 image classifier, export it, and run one demo prediction.

    Images are collected from ``data_path`` and labelled by the name of their
    parent folder. The fine-tuned learner is exported to
    ``export_path / 'model01.pkl'`` (the file ``predict_image`` loads by
    default), and a sample image is then classified and printed.
    """
    data_path = Path('G:/Users/15819/Desktop/Images')
    export_path = Path('G:/Users/15819/Desktop')
    batch_size = 32

    dls = DataBlock(
        blocks=(ImageBlock, CategoryBlock),
        get_items=get_image_files,
        # Fixed seed so the train/validation split is reproducible across runs.
        splitter=RandomSplitter(seed=42),
        get_y=parent_label,
        # Presizing: each item is resized to 460 px, then the batch
        # augmentations crop/scale it down to the 224 px training size.
        item_tfms=Resize(460),
        batch_tfms=[
            *aug_transforms(size=224, min_scale=0.75),
            Normalize.from_stats(*imagenet_stats),
        ],
    ).dataloaders(data_path, num_workers=4, bs=batch_size)

    learn = vision_learner(dls, resnet34, metrics=error_rate)
    # 3 epochs with the pretrained body frozen, then 5 epochs fully unfrozen.
    learn.fine_tune(5, freeze_epochs=3)
    # Export under export_path instead of repeating the directory as a literal.
    learn.export(export_path / 'model01.pkl')

    image_path = data_path / 'SmallCar' / '川A8K059.jpg'
    pred_class, confidence = predict_image(image_path)
    print(f"图片类别: {pred_class}, 置信度: {confidence}")
if __name__ == "__main__":
    # Only train/predict when run as a script. This guard also stops DataLoader
    # worker processes (num_workers=4 in main) from re-executing the training
    # when they import this module on spawn-based platforms such as Windows.
    main()