diff --git a/ver3.py b/ver3.py index aa284c0..5ce18dd 100644 --- a/ver3.py +++ b/ver3.py @@ -15,8 +15,7 @@ from PIL import Image def open_image(image_path): img = Image.open(image_path) img_cvt = img.resize((460,460)) - return img - + return img_cvt def predict_image(image_path): # 加载模型 @@ -36,7 +35,6 @@ def predict_image(image_path): def train(): data_path = Path('G:\\Users\\15819\\Desktop\\Images2') - export_path = Path('G:\\Users\\15819\\Desktop') blocks = (ImageBlock, CategoryBlock) batch_size = 32 dls = DataBlock( @@ -53,7 +51,7 @@ def train(): model.export('G:\\Users\\15819\\Desktop\\model01.pkl') def main(): - #train() + train() image_path = 'G:\\Users\\15819\\Desktop\\Images2\\SmallCar\\京M88888.jpg' pred_class, confidence = predict_image(image_path)