• A First Attempt at CAPTCHA Recognition with PaddlePaddle


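    Preface

    The infuriating Wisedu (金智) login system now forces a CAPTCHA on every login, which has taken a lot of the fun out of my little crawler.
    They have their policy, I have my countermeasure. Let's take it on!

    Preparing the dataset

    Good work needs good tools: the deep learning model has to be fed CAPTCHA images together with their correct labels.
    Labelling by hand is not going to happen, so let's be lazy and run a simple OCR pass instead.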

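    import os
    import time
    from multiprocessing import Pool

    import requests

    # Replace {your-school} with your school's authserver subdomain.
    CAPTCHA_URL = 'http://authserver.{your-school}.edu.cn/authserver/captcha.html'
    RAW_SAVE_PATH = 'datasets/'


    def save(filepath):
        """Download one CAPTCHA image and write it to filepath."""
        res = requests.get(CAPTCHA_URL)
        with open(filepath, 'wb') as f:
            f.write(res.content)
        time.sleep(0.1)  # short pause between requests


    def gen_filepath():
        """Yield target paths that are not on disk yet, so a rerun resumes."""
        for i in range(10 * 10000):
            filename = f"{i:08d}.jpg"
            filepath = os.path.join(RAW_SAVE_PATH, filename)
            if i % 10000 == 0:
                print(filename)  # log every 10,000th file name

            if os.path.exists(filepath):
                continue

            yield filepath


    if __name__ == '__main__':
        os.makedirs(RAW_SAVE_PATH, exist_ok=True)
        with Pool() as p:
            p.map(save, gen_filepath())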
    
    

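    The script above handles the download. Combined with the OCR labelling pass (not shown here), it yields a large set of labelled images for the training that follows.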
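    OCR guesses are noisy, so it helps to throw away obviously bad labels before training. Below is a minimal sketch of such a filter. The charset, the length of 4 and the case-insensitive matching are all assumptions for illustration, and `clean_ocr_label` is a hypothetical helper, not part of the original project.

```python
import string

# Assumptions for illustration only: the real charset and length depend on
# the CAPTCHA being scraped and are not stated in the post.
ALLOWED_CHARS = set(string.ascii_lowercase + string.digits)
EXPECTED_LENGTH = 4


def clean_ocr_label(raw, length=EXPECTED_LENGTH, allowed=ALLOWED_CHARS):
    """Return a normalised label, or None if the OCR guess looks wrong."""
    # OCR engines often emit stray whitespace and mixed case.
    label = "".join(raw.split()).lower()
    if len(label) != length:
        return None
    if any(ch not in allowed for ch in label):
        return None
    return label


if __name__ == "__main__":
    for guess in ["a1B2", " x9 k3\n", "abc", "ab-d"]:
        print(repr(guess), "->", clean_ocr_label(guess))
```

    Images whose guess comes back as `None` can simply be skipped; with around 100,000 downloads there is plenty of data left.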

    Training

    The reference tutorial is linked below; we follow it and tweak it as we go:
    https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/cv_case/image_ocr/image_ocr.html

    Network

    import paddle as pp

    # CHANNELS_BASE, YZM_LENGTH (the CAPTCHA length; "YZM" abbreviates the
    # pinyin "yanzhengma") and CLASSIFY_NUM are constants defined elsewhere
    # in the project.


    class Net(pp.nn.Layer):
        def __init__(self, is_infer: bool = False):
            super().__init__()
            self.is_infer = is_infer

            # Three convolutions extract features from the grayscale image.
            self.conv1 = pp.nn.Conv2D(in_channels=1,
                                      out_channels=CHANNELS_BASE,
                                      kernel_size=3)
            self.bn1 = pp.nn.BatchNorm2D(CHANNELS_BASE)
            self.conv2 = pp.nn.Conv2D(in_channels=CHANNELS_BASE,
                                      out_channels=CHANNELS_BASE * 2,
                                      kernel_size=3,
                                      stride=2)
            self.bn2 = pp.nn.BatchNorm2D(CHANNELS_BASE * 2)
            self.conv3 = pp.nn.Conv2D(in_channels=CHANNELS_BASE * 2,
                                      out_channels=CHANNELS_BASE,
                                      kernel_size=1)
            # Maps the flattened feature map (660 values per channel for this
            # input size) to YZM_LENGTH + 4 time steps for CTC.
            self.linear = pp.nn.Linear(in_features=660,
                                       out_features=YZM_LENGTH + 4)
            self.lstm = pp.nn.LSTM(input_size=CHANNELS_BASE,
                                   hidden_size=CHANNELS_BASE // 2,
                                   direction='bidirectional',
                                   time_major=True)
            self.linear2 = pp.nn.Linear(in_features=CHANNELS_BASE,
                                        out_features=CLASSIFY_NUM)

        def forward(self, ipt):
            # ipt: [batch, 1, HEIGHT, WIDTH]
            x = self.conv1(ipt)
            x = pp.nn.functional.relu(x)
            x = self.bn1(x)
            x = self.conv2(x)
            x = pp.nn.functional.relu(x)
            x = self.bn2(x)
            x = self.conv3(x)
            x = pp.nn.functional.relu(x)
            # [batch, channels, h, w] -> [batch, channels, h * w]
            x = pp.tensor.flatten(x, 2)
            x = self.linear(x)
            x = pp.nn.functional.relu(x)
            # [batch, channels, time] -> [time, batch, channels] for the LSTM
            x = x.transpose([2, 0, 1])
            x = self.lstm(x)[0]
            # [time, batch, CLASSIFY_NUM]
            x = self.linear2(x)

            if self.is_infer:
                # Back to batch-major, then pick the best class per time step.
                x = x.transpose([1, 0, 2])
                x = pp.nn.functional.softmax(x)
                x = pp.argmax(x, axis=-1)
            return x
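    The hard-coded `in_features=660` depends on the input image size, so it has to be recomputed whenever HEIGHT or WIDTH changes. A small sketch of that arithmetic follows, using the standard convolution output-size formula with the kernel sizes and strides from the network above (no padding, which is the Conv2D default). The 32x96 input in the example is hypothetical, not the size used in the post.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def flattened_features(height, width):
    """Size of the last axis after conv1 -> conv2 -> conv3 -> flatten(x, 2)."""
    # conv1: kernel 3, stride 1, no padding
    h, w = conv_out(height, 3), conv_out(width, 3)
    # conv2: kernel 3, stride 2, no padding
    h, w = conv_out(h, 3, stride=2), conv_out(w, 3, stride=2)
    # conv3: kernel 1 leaves the spatial size unchanged
    return h * w


if __name__ == "__main__":
    # Hypothetical input size, for illustration only.
    print(flattened_features(32, 96))
```

    Whatever this returns for the real image size is the value to pass as `in_features` of the first linear layer.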
    
    

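    Loss function

    class CTCLoss(pp.nn.Layer):
        def forward(self, ipt, label):
            # ipt: [time, batch, CLASSIFY_NUM]; label: [batch, YZM_LENGTH].
            # Lengths are fixed, so every batch must contain exactly
            # BATCH_SIZE samples (drop the last incomplete batch).
            input_lengths = pp.full(shape=[BATCH_SIZE, 1], fill_value=YZM_LENGTH + 4, dtype='int64')
            label_lengths = pp.full(shape=[BATCH_SIZE, 1], fill_value=YZM_LENGTH, dtype='int64')
            # The blank symbol takes the index right after the last real character.
            loss = pp.nn.functional.ctc_loss(ipt, label, input_lengths, label_lengths, blank=len(CHAR_LIST))
            return loss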
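    For reference, here is a rough sketch of how text labels might be turned into the integer labels this loss expects, and why the network emits a few more time steps than the CAPTCHA has characters. The charset below is an assumption (the post does not show `CHAR_LIST`), and both helper functions are hypothetical.

```python
import string

# Hypothetical charset; the post's actual CHAR_LIST is not shown.
CHAR_LIST = list(string.digits + string.ascii_lowercase)
CHAR_TO_INDEX = {ch: i for i, ch in enumerate(CHAR_LIST)}
BLANK_INDEX = len(CHAR_LIST)  # same convention as blank=len(CHAR_LIST)


def text_to_label(text):
    """Map each character to its index in CHAR_LIST."""
    return [CHAR_TO_INDEX[ch] for ch in text]


def min_ctc_steps(label):
    """Fewest time steps CTC needs to emit this label.

    Two identical neighbours must be separated by a blank, so every
    adjacent repeat costs one extra step.
    """
    repeats = sum(1 for a, b in zip(label, label[1:]) if a == b)
    return len(label) + repeats


if __name__ == "__main__":
    label = text_to_label("aab1")
    print(label, min_ctc_steps(label))
```

    A label of length n needs at most 2n - 1 steps, so `YZM_LENGTH + 4` steps is enough for any label as long as the CAPTCHA has no more than five characters.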
    
    

    Downloads

    Thanks to Baidu for the free V100 time, which is also the reason I picked PaddlePaddle.

    Code

    https://aistudio.baidu.com/aistudio/projectdetail/2060359

    Dataset

    https://aistudio.baidu.com/aistudio/datasetdetail/94535

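    Deployment

    This one bit me at first: I treated the .pdopt file (optimizer state) as the model file, and loading kept failing. The trained network actually has to be converted to an inference model first: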

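        from paddle.jit import to_static

        # Describe the input so the dynamic graph can be traced to a static one.
        inputs = pp.static.InputSpec(shape=[-1, 1, HEIGHT, WIDTH], dtype='float32', name='img')

        # Rebuild the network in inference mode and load the trained weights.
        net = Net(is_infer=True)
        model_state_dict = pp.load(PARAMS_PATH)
        net.set_state_dict(model_state_dict)

        # Optimizer state (.pdopt) only matters for resuming training;
        # the export below does not use it.
        optimizer = pp.optimizer.Adam(learning_rate=0.0001, parameters=net.parameters())
        opt_state_dict = pp.load(MODEL_PATH)
        optimizer.set_state_dict(opt_state_dict)

        # Convert to a static graph and save the inference model.
        net = to_static(net, input_spec=[inputs])
        pp.jit.save(net, 'models/inference')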
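    The exported model then has to be loaded into a predictor. The sketch below shows one way to do that with the Paddle Inference Python API. It assumes Paddle 2.x file naming, where `pp.jit.save(net, 'models/inference')` writes `inference.pdmodel` and `inference.pdiparams`; newer releases may name the files differently. `inference_files` and `load_predictor` are hypothetical helpers, not code from the original project.

```python
def inference_files(prefix):
    """Model and parameter file paths for a given save prefix (Paddle 2.x naming)."""
    return prefix + ".pdmodel", prefix + ".pdiparams"


def load_predictor(prefix="models/inference"):
    """Create a CPU predictor from an exported inference model."""
    # Imported lazily so this module can be imported without Paddle installed.
    from paddle.inference import Config, create_predictor

    model_file, params_file = inference_files(prefix)
    config = Config(model_file, params_file)
    config.disable_gpu()  # CPU is plenty for one small image at a time
    return create_predictor(config)


if __name__ == "__main__":
    print(inference_files("models/inference"))
```

    Creating the predictor once at module level and reusing it keeps each call cheap.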
    
    

    After that, the model is ready to use:

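    import numpy as np

    # _predictor is a module-level Paddle Inference predictor; parse_img,
    # pre_process, ctc_decode and label_arr2text are project helpers.


    def predict_captcha(img):
        img = parse_img(img)
        img = pre_process(img)
        img = np.expand_dims(img, axis=0)  # add the batch dimension

        # Copy the image into the predictor's input tensor and run it.
        input_names = _predictor.get_input_names()
        input_handle = _predictor.get_input_handle(input_names[0])
        input_handle.reshape([1, 1, HEIGHT, WIDTH])
        input_handle.copy_from_cpu(img)
        _predictor.run()

        # Fetch the per-step class indices and decode them to text.
        output_names = _predictor.get_output_names()
        output_handle = _predictor.get_output_handle(output_names[0])
        output_data = output_handle.copy_to_cpu()
        return label_arr2text(ctc_decode(output_data[0]))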
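    The decoding helpers are not shown in the post. A minimal sketch of greedy CTC decoding is given below: collapse runs of the same index, then drop the blanks. This is the standard greedy rule, not necessarily the author's implementation, and the function names and the digits-only charset in the example are hypothetical.

```python
from itertools import groupby


def ctc_greedy_decode(indices, blank):
    """Collapse consecutive repeats, then remove blank symbols."""
    collapsed = [key for key, _ in groupby(indices)]
    return [key for key in collapsed if key != blank]


def label_to_text(label, char_list):
    """Turn a list of class indices back into a string."""
    return "".join(char_list[i] for i in label)


if __name__ == "__main__":
    chars = list("0123456789")  # hypothetical charset
    blank = len(chars)          # blank index follows the last character
    steps = [1, 1, blank, 1, 2, 2, blank, blank]
    print(label_to_text(ctc_greedy_decode(steps, blank), chars))
```

    The blank between the two runs of `1` is what lets a doubled character survive the collapse.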
    
    

    Afterword

    The model is now deployed in my own library (https://github.com/Licsber/licsber-pypi). For CAPTCHA recognition, all you need is:

    from licsber.auth import predict_captcha
    

    For the school's SSO login:

    from licsber.auth import get_wisedu_session
    

    And with that, the crawler can be happy again.

  • Original post: https://www.cnblogs.com/licsber/p/14880618.html