• 决策树算法理解和应用


    决策树算法是一种监督式学习算法,它简单好用,易于解释,在金融科技,数字健康,教育服务,消费互联网等许多领域发挥着积极作用。决策树算法学习的结果,类似下图结构:

    本文首先介绍决策树的原理,然后基于tidymodels框架设计和执行决策树算法以解决实际问题。

    一、决策树算法原理

    决策树算法的理解,可以参考下面的算法伪代码(来源:数据挖掘概念与技术)

    决策树算法需要解决关键问题

    1 如何选择特征做拆分?

    主要采用这些度量方法

    1)信息增益

    最大化变量的信息增益,确定变量的拆分以及先后顺序

    2)增益率

    增益率用于优化信息增益偏向于具有变量值分布不一致所导致的问题。

    3)Gini 指数

    2 如何对树的结构进行裁剪?

    目的:防止学习的模型过拟合(对训练集效果好,而测试集上效果不佳)

    使用统计测量删除不可靠的分支或者有少量样本组成的分支。实际操作中,可以通过设置树生成的一些超参数来控制树的结构,比方说:

    1)树的最大深度max_depth

    2)树的最小划分样本数min_samples_split

    3)数的叶子节点最小样本数min_samples_leaf

    通过裁剪技术,可以让树更加简洁,容易理解,也可提提升模型的泛化性能。

    决策树算法的优点:

    • 简单可解释

    • 可以处理各种数据

    • 非参数模型

    • 稳健

    • 快速

    决策树算法的缺点:

    • 过度拟合问题

    • 不稳定问题

    • 偏差问题

    • 优化问题

    二、决策树算法应用案例

    利用决策树算法预测Scooby Doo monsters是否真实?

    第一步:数据理解与准备

    options(warn = -1)
    library(tidyverse)

    # 数据获取
    scooby_raw <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-07-13/scoobydoo.csv")

    scooby_raw %>%
      filter(monster_amount > 0) %>%
      count(monster_real)

    第二步:从不同维度做洞察

    时间维度

    scooby_raw %>%
      filter(monster_amount > 0) %>%
      count(
        year_aired = 10 * ((lubridate::year(date_aired) + 1) %/% 10),
        monster_real
      ) %>%
      mutate(year_aired = factor(year_aired)) %>%
      ggplot(aes(year_aired, n, fill = monster_real)) +
      geom_col(position = position_dodge(preserve = "single"), alpha = 0.8) +
      labs(x = "Date aired", y = "Monsters per decade", fill = "Real monster?")

    imdb评分维度

    scooby_raw %>%
      filter(monster_amount > 0) %>%
      mutate(imdb = parse_number(imdb)) %>%
      ggplot(aes(imdb, after_stat(density), fill = monster_real)) +
      geom_histogram(position = "identity", alpha = 0.5) +
      labs(x = "IMDB rating", y = "Density", fill = "Real monster?")

    第三步:决策树模型构建

    数据集划分

    训练集,用于训练模型

    测试集,用于评价模型性能

    训练集中利用bootstraps策略用于做超参数选择和优化

    library(tidymodels)
    set.seed(123)
    scooby_split <- scooby_raw %>%
      mutate(
        imdb = parse_number(imdb),
        year_aired = lubridate::year(date_aired)
      ) %>%
      filter(monster_amount > 0, !is.na(imdb)) %>%
      mutate(
        monster_real = case_when(
          monster_real == "FALSE" ~ "fake",
          TRUE ~ "real"
        ),
        monster_real = factor(monster_real)
      ) %>%
      select(year_aired, imdb, monster_real, title) %>%
      initial_split(strata = monster_real)
    scooby_train <- training(scooby_split)
    scooby_test <- testing(scooby_split)

    set.seed(234)
    scooby_folds <- bootstraps(scooby_train, strata = monster_real)
    scooby_folds

    决策树模型设计


    # 设计决策树模型
    tree_spec <-
      decision_tree(
        cost_complexity = tune(),
        tree_depth = tune(),
        min_n = tune()
      ) %>%
      set_mode("classification") %>%
      set_engine("rpart")

    tree_spec
    tree_grid <- grid_regular(cost_complexity(), tree_depth(), min_n(), levels = 4)
    tree_grid

    doParallel::registerDoParallel()

    set.seed(345)
    tree_rs <-
      tune_grid(
        tree_spec,
        monster_real ~ year_aired + imdb,
        resamples = scooby_folds,
        grid = tree_grid,
        metrics = metric_set(accuracy, roc_auc, sensitivity, specificity)
      )

    tree_rs


    第四步:模型性能评价


    # 模型评估和理解
    show_best(tree_rs)
    # 超参数可视化
    autoplot(tree_rs) + theme_light(base_family = "IBMPlexSans")

    # 基于所关注的指标选择最佳模型的超参数
    simpler_tree <- select_by_one_std_err(tree_rs,
                                          -cost_complexity,
                                          metric = "roc_auc"
    )
    # 根据最佳参数重构模型
    final_tree <- finalize_model(tree_spec, simpler_tree)
    final_fit <- fit(final_tree, monster_real ~ year_aired + imdb, scooby_train)

    final_rs <- last_fit(final_tree, monster_real ~ year_aired + imdb, scooby_split)

    collect_metrics(final_rs)

    第五步:模型结果可视化


    # 决策树执行决策的可视化
    library(parttree)

    scooby_train %>%
      ggplot(aes(imdb, year_aired)) +
      geom_parttree(data = final_fit, aes(fill = monster_real), alpha = 0.2) +
      geom_jitter(alpha = 0.7, width = 0.05, height = 0.2, aes(color = monster_real))

     

    参考资料:

    1 Understanding the Mathematics Behind Decision Trees | by Nikita Sharma | Heartbeat (fritz.ai)

    2 https://juliasilge.com/blog/scooby-doo/

    3 https://github.com/grantmcdermott/parttree

  • 相关阅读:
    [Python] Python2 、Python3 urllib 模块对应关系
    [Python] Mac pip安装的模块包路径以及常规python路径
    git 使用详解
    版本控制工具简介
    python基础练习题(二)
    python简介,安装
    python基础练习题(一)
    python练习题之面向对象(三)
    python之input函数,if,else条件语句使用的练习题(一)
    C++ 静态变量
  • 原文地址:https://www.cnblogs.com/purple5252/p/15119385.html
Copyright © 2020-2023  润新知