#coding= utf-8
"""Classify every image in a directory with a trained SimpleNet checkpoint."""
import os
import torch
from data_pipe import get_data  # NOTE(review): unused in this script
from model import SimpleNet
import numpy as np
import cv2
from PIL import Image  # NOTE(review): unused in this script

# Side length in pixels of the square 3-channel images fed to the model. The
# original script hard-coded this through reshape((1, 3, 32, 32)).
INPUT_SIZE = 32
DEFAULT_MODEL_PATH = "./models/model_10.pth"


class Infer(object):
    """Load a SimpleNet checkpoint and label images as 0 or 1."""

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Build SimpleNet, load weights from model_path, switch to eval mode.

        Args:
            model_path: path to a state_dict saved with torch.save.
        """
        self.model = SimpleNet()
        # map_location="cpu" lets a checkpoint that was saved from GPU tensors
        # load on a CPU-only machine; the input tensors built below are CPU
        # tensors anyway.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()

    def _infer(self, img_tensor, threshold=0.5):
        """Run the model on one batched image and threshold its output.

        Args:
            img_tensor: float tensor of shape (1, 3, INPUT_SIZE, INPUT_SIZE).
            threshold: the label is 1 when the model output exceeds this.

        Returns:
            int: 1 if output > threshold else 0.
        """
        with torch.no_grad():
            score = self.model(img_tensor)
        # .item() requires a single-element output, i.e. one score for the
        # single image in the batch; presumably a probability-like value in
        # [0, 1] given the 0.5 default -- confirm against SimpleNet.
        return 1 if score.item() > threshold else 0

    @staticmethod
    def _load_image(img_path):
        """Read img_path into a (1, 3, INPUT_SIZE, INPUT_SIZE) float tensor.

        Pixel values are scaled to [0, 1] and channels are in RGB order.

        Returns:
            torch.Tensor, or None when OpenCV cannot decode the file.
        """
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread reports unreadable / non-image paths by returning None.
            return None
        # imread yields BGR; reorder to RGB. The original code passed
        # COLOR_RGB2BGR, which performs exactly the same channel swap.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if img.shape[:2] != (INPUT_SIZE, INPUT_SIZE):
            # Previously such images crashed in reshape; resize instead.
            # NOTE(review): confirm this matches the training preprocessing.
            img = cv2.resize(img, (INPUT_SIZE, INPUT_SIZE),
                             interpolation=cv2.INTER_AREA)
        # HWC uint8 -> CHW float in [0, 1], then add the batch dimension.
        tensor = torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
        return tensor.unsqueeze(0)

    def predict(self, path):
        """Classify every entry directly inside directory path.

        Prints each image path followed by its predicted label (0 or 1).
        Entries OpenCV cannot read as images are reported and skipped instead
        of raising.

        Returns:
            list of (img_path, label) tuples, in sorted filename order.
        """
        results = []
        # Sort so the output order is deterministic; os.listdir order is not.
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            print(img_path)
            img_tensor = self._load_image(img_path)
            if img_tensor is None:
                print("skipped: not a readable image")
                continue
            result = self._infer(img_tensor)
            print(result)
            results.append((img_path, result))
        return results


if __name__ == "__main__":
    path = "./test_images"
    Infer().predict(path)