GLM dans R: modèle linéaire généralisé avec exemple

Table des matières:

Anonim

Qu'est-ce que la régression logistique?

La régression logistique est utilisée pour prédire une classe, c'est-à-dire une probabilité. La régression logistique peut prédire un résultat binaire avec précision.

Imaginez que vous vouliez prédire si un prêt est refusé / accepté en fonction de nombreux attributs. La régression logistique est de la forme 0/1. y = 0 si un prêt est rejeté, y = 1 s'il est accepté.

Un modèle de régression logistique diffère d'un modèle de régression linéaire de deux manières.

  • Tout d'abord, la régression logistique n'accepte qu'une entrée dichotomique (binaire) comme variable dépendante (c'est-à-dire un vecteur de 0 et 1).
  • Deuxièmement, le résultat est mesuré par la fonction de lien probabiliste suivante appelée sigmoïde en raison de sa forme en S.

La sortie de la fonction est toujours comprise entre 0 et 1. Vérifiez l'image ci-dessous

La fonction sigmoïde renvoie des valeurs de 0 à 1. Pour la tâche de classification, nous avons besoin d'une sortie discrète de 0 ou 1.

Pour convertir un flux continu en valeur discrète, nous pouvons définir une borne de décision à 0,5. Toutes les valeurs au-dessus de ce seuil sont classées comme 1

Dans ce tutoriel, vous apprendrez

  • Qu'est-ce que la régression logistique?
  • Comment créer un modèle de revêtement généralisé (GLM)
  • Étape 1) Vérifiez les variables continues
  • Étape 2) Vérifiez les variables de facteur
  • Étape 3) Ingénierie des fonctionnalités
  • Étape 4) Statistique récapitulative
  • Étape 5) Train / ensemble de test
  • Étape 6) Construisez le modèle
  • Étape 7) Évaluer les performances du modèle

Comment créer un modèle de revêtement généralisé (GLM)

Utilisons l' ensemble de données adultes pour illustrer la régression logistique. L '«adulte» est un excellent ensemble de données pour la tâche de classification. L'objectif est de prédire si le revenu annuel en dollars d'un individu dépassera 50 000. L'ensemble de données contient 46033 observations et dix caractéristiques:

  • age: âge de l'individu. Numérique
  • éducation: niveau d'éducation de l'individu. Facteur.
  • marital.status: état matrimonial de l'individu. Facteur c.-à-d. Jamais marié, conjoint-civil,…
  • gender: Sexe de l'individu. Facteur, c'est-à-dire homme ou femme
  • revenu: variable cible. Revenu supérieur ou inférieur à 50K. Facteur ie> 50K, <= 50K

entre autres

library(dplyr)data_adult <-read.csv("https://raw.githubusercontent.com/guru99-edu/R-Programming/master/adult.csv")glimpse(data_adult)

Production:

Observations: 48,842Variables: 10$ x  1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,… $ age  25, 38, 28, 44, 18, 34, 29, 63, 24, 55, 65, 36, 26… $ workclass  Private, Private, Local-gov, Private, ?, Private,… $ education  11th, HS-grad, Assoc-acdm, Some-college, Some-col… $ educational.num  7, 9, 12, 10, 10, 6, 9, 15, 10, 4, 9, 13, 9, 9, 9,… $ marital.status  Never-married, Married-civ-spouse, Married-civ-sp… $ race  Black, White, White, Black, White, White, Black,… $ gender  Male, Male, Male, Male, Female, Male, Male, Male,… $ hours.per.week  40, 50, 40, 40, 30, 30, 40, 32, 40, 10, 40, 40, 39… $ income  <=50K, <=50K, >50K, >50K, <=50K, <=50K, <=50K, >5… 

Nous procéderons comme suit:

  • Étape 1: vérifier les variables continues
  • Étape 2: Vérifiez les variables de facteur
  • Étape 3: Ingénierie des fonctionnalités
  • Étape 4: statistique récapitulative
  • Étape 5: Train / ensemble de test
  • Étape 6: Construisez le modèle
  • Étape 7: Évaluer les performances du modèle
  • étape 8: améliorer le modèle

Votre tâche est de prédire quelle personne aura un revenu supérieur à 50K.

Dans ce tutoriel, chaque étape sera détaillée pour effectuer une analyse sur un jeu de données réel.

Étape 1) Vérifiez les variables continues

Dans la première étape, vous pouvez voir la distribution des variables continues.

continuous <-select_if(data_adult, is.numeric)summary(continuous)

Explication du code

  • continu <- select_if (data_adult, is.numeric): Utilisez la fonction select_if () de la bibliothèque dplyr pour sélectionner uniquement les colonnes numériques
  • résumé (continu): imprimer la statistique récapitulative

Production:

## X age educational.num hours.per.week## Min. : 1 Min. :17.00 Min. : 1.00 Min. : 1.00## 1st Qu.:11509 1st Qu.:28.00 1st Qu.: 9.00 1st Qu.:40.00## Median :23017 Median :37.00 Median :10.00 Median :40.00## Mean :23017 Mean :38.56 Mean :10.13 Mean :40.95## 3rd Qu.:34525 3rd Qu.:47.00 3rd Qu.:13.00 3rd Qu.:45.00## Max. :46033 Max. :90.00 Max. :16.00 Max. :99.00

À partir du tableau ci-dessus, vous pouvez voir que les données ont des échelles et des heures totalement différentes.

Vous pouvez y faire face en deux étapes:

  • 1: Tracez la répartition des heures par semaine
  • 2: Standardiser les variables continues
  1. Tracer la distribution

Regardons de plus près la répartition des heures par semaine

# Histogram with kernel density curvelibrary(ggplot2)ggplot(continuous, aes(x = hours.per.week)) +geom_density(alpha = .2, fill = "#FF6666")

Production:

La variable a beaucoup de valeurs aberrantes et une distribution mal définie. Vous pouvez résoudre partiellement ce problème en supprimant les 0,01% des heures par semaine.

Syntaxe de base du quantile:

quantile(variable, percentile)arguments:-variable: Select the variable in the data frame to compute the percentile-percentile: Can be a single value between 0 and 1 or multiple value. If multiple, use this format: `c(A,B,C,… )- `A`,`B`,`C` and `… ` are all integer from 0 to 1.

Nous calculons le 2 percentile supérieur

top_one_percent <- quantile(data_adult$hours.per.week, .99)top_one_percent

Explication du code

  • quantile (data_adult $ hours.per.week, .99): Calcule la valeur des 99 pour cent du temps de travail

Production:

## 99%## 80 

98 pour cent de la population travaille moins de 80 heures par semaine.

Vous pouvez déposer les observations au-dessus de ce seuil. Vous utilisez le filtre de la bibliothèque dplyr.

data_adult_drop <-data_adult %>%filter(hours.per.week

Production:

## [1] 45537 10 
  1. Standardiser les variables continues

Vous pouvez standardiser chaque colonne pour améliorer les performances car vos données n'ont pas la même échelle. Vous pouvez utiliser la fonction mutate_if de la bibliothèque dplyr. La syntaxe de base est:

mutate_if(df, condition, funs(function))arguments:-`df`: Data frame used to compute the function- `condition`: Statement used. Do not use parenthesis- funs(function): Return the function to apply. Do not use parenthesis for the function

Vous pouvez standardiser les colonnes numériques comme suit:

data_adult_rescale <- data_adult_drop % > %mutate_if(is.numeric, funs(as.numeric(scale(.))))head(data_adult_rescale)

Explication du code

  • mutate_if (is.numeric, funs (scale)): La condition est uniquement une colonne numérique et la fonction est une échelle

Production:

## X age workclass education educational.num## 1 -1.732680 -1.02325949 Private 11th -1.22106443## 2 -1.732605 -0.03969284 Private HS-grad -0.43998868## 3 -1.732530 -0.79628257 Local-gov Assoc-acdm 0.73162494## 4 -1.732455 0.41426100 Private Some-college -0.04945081## 5 -1.732379 -0.34232873 Private 10th -1.61160231## 6 -1.732304 1.85178149 Self-emp-not-inc Prof-school 1.90323857## marital.status race gender hours.per.week income## 1 Never-married Black Male -0.03995944 <=50K## 2 Married-civ-spouse White Male 0.86863037 <=50K## 3 Married-civ-spouse White Male -0.03995944 >50K## 4 Married-civ-spouse Black Male -0.03995944 >50K## 5 Never-married White Male -0.94854924 <=50K## 6 Married-civ-spouse White Male -0.76683128 >50K

Étape 2) Vérifiez les variables de facteur

Cette étape a deux objectifs:

  • Vérifiez le niveau dans chaque colonne catégorielle
  • Définissez de nouveaux niveaux

Nous allons diviser cette étape en trois parties:

  • Sélectionnez les colonnes catégorielles
  • Stocker le graphique à barres de chaque colonne dans une liste
  • Imprimer les graphiques

Nous pouvons sélectionner les colonnes de facteurs avec le code ci-dessous:

# Select categorical columnfactor <- data.frame(select_if(data_adult_rescale, is.factor))ncol(factor)

Explication du code

  • data.frame (select_if (data_adult, is.factor)): Nous stockons les colonnes factor dans factor dans un type de trame de données. La bibliothèque ggplot2 nécessite un objet de bloc de données.

Production:

## [1] 6 

L'ensemble de données contient 6 variables catégorielles

La deuxième étape est plus habile. Vous souhaitez tracer un graphique à barres pour chaque colonne dans le facteur de bloc de données. Il est plus pratique d'automatiser le processus, surtout dans le cas où il y a beaucoup de colonnes.

library(ggplot2)# Create graph for each columngraph <- lapply(names(factor),function(x)ggplot(factor, aes(get(x))) +geom_bar() +theme(axis.text.x = element_text(angle = 90)))

Explication du code

  • lapply (): Utilisez la fonction lapply () pour passer une fonction dans toutes les colonnes de l'ensemble de données. Vous stockez la sortie dans une liste
  • function (x): La fonction sera traitée pour chaque x. Ici x est les colonnes
  • ggplot (factor, aes (get (x))) + geom_bar () + theme (axis.text.x = element_text (angle = 90)): Crée un graphique à barres pour chaque élément x. Notez que pour renvoyer x sous forme de colonne, vous devez l'inclure dans le get ()

La dernière étape est relativement simple. Vous souhaitez imprimer les 6 graphiques.

# Print the graphgraph

Production:

## [[1]]

## ## [[2]]

## ## [[3]]

## ## [[4]]

## ## [[5]]

## ## [[6]]

Remarque: utilisez le bouton suivant pour accéder au graphique suivant

Étape 3) Ingénierie des fonctionnalités

Refonte de l'éducation

À partir du graphique ci-dessus, vous pouvez voir que la variable éducation a 16 niveaux. C'est substantiel, et certains niveaux ont un nombre d'observations relativement faible. Si vous souhaitez améliorer la quantité d'informations que vous pouvez obtenir à partir de cette variable, vous pouvez la refondre au niveau supérieur. À savoir, vous créez des groupes plus importants avec un niveau d'éducation similaire. Par exemple, un faible niveau d'éducation sera converti en abandon. Les niveaux d'enseignement plus élevés seront changés en master.

Voici le détail:

Ancien niveau

Nouveau niveau

Préscolaire

abandonner

10e

Abandonner

11ème

Abandonner

12ème

Abandonner

1er-4e

Abandonner

5e-6e

Abandonner

7e-8e

Abandonner

9ème

Abandonner

HS-Grad

HighGrad

Certains-université

Communauté

Assoc-acdm

Communauté

Assoc-voc

Communauté

Les bacheliers

Les bacheliers

Maîtrise

Maîtrise

Prof-école

Maîtrise

Doctorat

Doctorat

recast_data <- data_adult_rescale % > %select(-X) % > %mutate(education = factor(ifelse(education == "Preschool" | education == "10th" | education == "11th" | education == "12th" | education == "1st-4th" | education == "5th-6th" | education == "7th-8th" | education == "9th", "dropout", ifelse(education == "HS-grad", "HighGrad", ifelse(education == "Some-college" | education == "Assoc-acdm" | education == "Assoc-voc", "Community",ifelse(education == "Bachelors", "Bachelors",ifelse(education == "Masters" | education == "Prof-school", "Master", "PhD")))))))

Explication du code

  • Nous utilisons le verbe mutate de la bibliothèque dplyr. Nous changeons les valeurs de l'éducation avec la déclaration ifelse

Dans le tableau ci-dessous, vous créez une statistique récapitulative pour voir, en moyenne, combien d'années d'études (valeur z) il faut pour atteindre le Bachelor, Master ou PhD.

recast_data % > %group_by(education) % > %summarize(average_educ_year = mean(educational.num),count = n()) % > %arrange(average_educ_year)

Production:

## # A tibble: 6 x 3## education average_educ_year count##   ## 1 dropout -1.76147258 5712## 2 HighGrad -0.43998868 14803## 3 Community 0.09561361 13407## 4 Bachelors 1.12216282 7720## 5 Master 1.60337381 3338## 6 PhD 2.29377644 557

Refonte de l'état matrimonial

Il est également possible de créer des niveaux inférieurs pour l'état matrimonial. Dans le code suivant, vous modifiez le niveau comme suit:

Ancien niveau

Nouveau niveau

Jamais marié

Pas marié

Marié-conjoint-absent

Pas marié

Marié-AF-conjoint

Marié

Marié-civil-conjoint

Séparé

Séparé

Divorcé

Les veuves

Veuve

# Change level marryrecast_data <- recast_data % > %mutate(marital.status = factor(ifelse(marital.status == "Never-married" | marital.status == "Married-spouse-absent", "Not_married", ifelse(marital.status == "Married-AF-spouse" | marital.status == "Married-civ-spouse", "Married", ifelse(marital.status == "Separated" | marital.status == "Divorced", "Separated", "Widow")))))
Vous pouvez vérifier le nombre d'individus dans chaque groupe.
table(recast_data$marital.status)

Production:

## ## Married Not_married Separated Widow## 21165 15359 7727 1286 

Étape 4) Statistique récapitulative

Il est temps de vérifier quelques statistiques sur nos variables cibles. Dans le graphique ci-dessous, vous comptez le pourcentage d'individus gagnant plus de 50 000 personnes compte tenu de leur sexe.

# Plot gender incomeggplot(recast_data, aes(x = gender, fill = income)) +geom_bar(position = "fill") +theme_classic()

Production:

Ensuite, vérifiez si l'origine de l'individu affecte ses gains.

# Plot origin incomeggplot(recast_data, aes(x = race, fill = income)) +geom_bar(position = "fill") +theme_classic() +theme(axis.text.x = element_text(angle = 90))

Production:

Le nombre d'heures de travail par sexe.

# box plot gender working timeggplot(recast_data, aes(x = gender, y = hours.per.week)) +geom_boxplot() +stat_summary(fun.y = mean,geom = "point",size = 3,color = "steelblue") +theme_classic()

Production:

La boîte à moustaches confirme que la répartition du temps de travail correspond à différents groupes. Dans la boîte à moustaches, les deux sexes n'ont pas d'observations homogènes.

Vous pouvez vérifier la densité du temps de travail hebdomadaire par type d'enseignement. Les distributions ont de nombreux choix distincts. Cela peut probablement s'expliquer par le type de contrat aux États-Unis.

# Plot distribution working time by educationggplot(recast_data, aes(x = hours.per.week)) +geom_density(aes(color = education), alpha = 0.5) +theme_classic()

Explication du code

  • ggplot (recast_data, aes (x = hours.per.week)): Un graphique de densité ne nécessite qu'une seule variable
  • geom_density (aes (color = education), alpha = 0.5): L'objet géométrique pour contrôler la densité

Production:

Pour confirmer vos pensées, vous pouvez effectuer un test ANOVA à sens unique:

anova <- aov(hours.per.week~education, recast_data)summary(anova)

Production:

## Df Sum Sq Mean Sq F value Pr(>F)## education 5 1552 310.31 321.2 <2e-16 ***## Residuals 45531 43984 0.97## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Le test ANOVA confirme la différence de moyenne entre les groupes.

Non-linéarité

Avant d'exécuter le modèle, vous pouvez voir si le nombre d'heures travaillées est lié à l'âge.

library(ggplot2)ggplot(recast_data, aes(x = age, y = hours.per.week)) +geom_point(aes(color = income),size = 0.5) +stat_smooth(method = 'lm',formula = y~poly(x, 2),se = TRUE,aes(color = income)) +theme_classic()

Explication du code

  • ggplot (recast_data, aes (x = age, y = hours.per.week)): Définit l'esthétique du graphique
  • geom_point (aes (color = Income), size = 0.5): Construisez le dot plot
  • stat_smooth (): Ajoutez la ligne de tendance avec les arguments suivants:
    • method = 'lm': Tracez la valeur ajustée si la régression linéaire
    • formule = y ~ poly (x, 2): Ajuster une régression polynomiale
    • se = TRUE: Ajouter l'erreur standard
    • aes (couleur = revenu): décomposer le modèle par revenu

Production:

En un mot, vous pouvez tester les termes d'interaction dans le modèle pour détecter l'effet de non-linéarité entre le temps de travail hebdomadaire et d'autres fonctionnalités. Il est important de détecter dans quelles conditions le temps de travail diffère.

Corrélation

Le prochain contrôle consiste à visualiser la corrélation entre les variables. Vous convertissez le type de niveau de facteur en numérique afin de pouvoir tracer une carte thermique contenant le coefficient de corrélation calculé avec la méthode Spearman.

library(GGally)# Convert data to numericcorr <- data.frame(lapply(recast_data, as.integer))# Plot the graphggcorr(corr,method = c("pairwise", "spearman"),nbreaks = 6,hjust = 0.8,label = TRUE,label_size = 3,color = "grey50")

Explication du code

  • data.frame (lapply (recast_data, as.integer)): Convertit les données en numérique
  • ggcorr () trace la carte thermique avec les arguments suivants:
    • method: Méthode de calcul de la corrélation
    • nbreaks = 6: nombre de pause
    • hjust = 0.8: Position de contrôle du nom de la variable dans le tracé
    • label = TRUE: ajoutez des libellés au centre des fenêtres
    • label_size = 3: étiquettes de taille
    • color = "grey50"): Couleur de l'étiquette

Production:

Étape 5) Train / ensemble de test

Toute tâche d'apprentissage automatique supervisé nécessite de répartir les données entre un train et un ensemble de test. Vous pouvez utiliser la «fonction» que vous avez créée dans les autres tutoriels d'apprentissage supervisé pour créer un train / ensemble de test.

set.seed(1234)create_train_test <- function(data, size = 0.8, train = TRUE) {n_row = nrow(data)total_row = size * n_rowtrain_sample <- 1: total_rowif (train == TRUE) {return (data[train_sample, ])} else {return (data[-train_sample, ])}}data_train <- create_train_test(recast_data, 0.8, train = TRUE)data_test <- create_train_test(recast_data, 0.8, train = FALSE)dim(data_train)

Production:

## [1] 36429 9
dim(data_test)

Production:

## [1] 9108 9 

Étape 6) Construisez le modèle

Pour voir comment l'algorithme fonctionne, vous utilisez le package glm (). Le modèle linéaire généralisé est une collection de modèles. La syntaxe de base est:

glm(formula, data=data, family=linkfunction()Argument:- formula: Equation used to fit the model- data: dataset used- Family: - binomial: (link = "logit")- gaussian: (link = "identity")- Gamma: (link = "inverse")- inverse.gaussian: (link = "1/mu^2")- poisson: (link = "log")- quasi: (link = "identity", variance = "constant")- quasibinomial: (link = "logit")- quasipoisson: (link = "log")

Vous êtes prêt à estimer le modèle logistique pour répartir le niveau de revenu entre un ensemble de fonctionnalités.

formula <- income~.logit <- glm(formula, data = data_train, family = 'binomial')summary(logit)

Explication du code

  • formule <- revenu ~.: Créez le modèle adapté
  • logit <- glm (formula, data = data_train, family = 'binomial'): Ajuster un modèle logistique (family = 'binomial') avec les données data_train.
  • summary (logit): imprimer le résumé du modèle

Production:

#### Call:## glm(formula = formula, family = "binomial", data = data_train)## ## Deviance Residuals:## Min 1Q Median 3Q Max## -2.6456 -0.5858 -0.2609 -0.0651 3.1982#### Coefficients:## Estimate Std. Error z value Pr(>|z|)## (Intercept) 0.07882 0.21726 0.363 0.71675## age 0.41119 0.01857 22.146 < 2e-16 ***## workclassLocal-gov -0.64018 0.09396 -6.813 9.54e-12 ***## workclassPrivate -0.53542 0.07886 -6.789 1.13e-11 ***## workclassSelf-emp-inc -0.07733 0.10350 -0.747 0.45499## workclassSelf-emp-not-inc -1.09052 0.09140 -11.931 < 2e-16 ***## workclassState-gov -0.80562 0.10617 -7.588 3.25e-14 ***## workclassWithout-pay -1.09765 0.86787 -1.265 0.20596## educationCommunity -0.44436 0.08267 -5.375 7.66e-08 ***## educationHighGrad -0.67613 0.11827 -5.717 1.08e-08 ***## educationMaster 0.35651 0.06780 5.258 1.46e-07 ***## educationPhD 0.46995 0.15772 2.980 0.00289 **## educationdropout -1.04974 0.21280 -4.933 8.10e-07 ***## educational.num 0.56908 0.07063 8.057 7.84e-16 ***## marital.statusNot_married -2.50346 0.05113 -48.966 < 2e-16 ***## marital.statusSeparated -2.16177 0.05425 -39.846 < 2e-16 ***## marital.statusWidow -2.22707 0.12522 -17.785 < 2e-16 ***## raceAsian-Pac-Islander 0.08359 0.20344 0.411 0.68117## raceBlack 0.07188 0.19330 0.372 0.71001## raceOther 0.01370 0.27695 0.049 0.96054## raceWhite 0.34830 0.18441 1.889 0.05894 .## genderMale 0.08596 0.04289 2.004 0.04506 *## hours.per.week 0.41942 0.01748 23.998 < 2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## (Dispersion parameter for binomial family taken to be 1)## ## Null deviance: 40601 on 36428 degrees of freedom## Residual deviance: 27041 on 36406 degrees of freedom## AIC: 27087#### Number of Fisher Scoring iterations: 6

Le résumé de notre modèle révèle des informations intéressantes. La performance d'une régression logistique est évaluée avec des métriques clés spécifiques.

  • AIC (Akaike Information Criteria): C'est l'équivalent de R2 en régression logistique. Il mesure l'ajustement lorsqu'une pénalité est appliquée au nombre de paramètres. Des valeurs AIC plus petites indiquent que le modèle est plus proche de la vérité.
  • Déviance nulle: adapte le modèle uniquement avec l'interception. Le degré de liberté est n-1. Nous pouvons l'interpréter comme une valeur Chi-carré (valeur ajustée différente du test d'hypothèse de valeur réelle).
  • Déviance résiduelle: modèle avec toutes les variables. Il est également interprété comme un test d'hypothèse du chi carré.
  • Nombre d'itérations de notation de Fisher: nombre d'itérations avant la convergence.

La sortie de la fonction glm () est stockée dans une liste. Le code ci-dessous montre tous les éléments disponibles dans la variable logit que nous avons construite pour évaluer la régression logistique.

# La liste est très longue, n'imprimez que les trois premiers éléments

lapply(logit, class)[1:3]

Production:

## $coefficients## [1] "numeric"#### $residuals## [1] "numeric"#### $fitted.values## [1] "numeric"

Chaque valeur peut être extraite avec le signe $ suivi du nom des métriques. Par exemple, vous avez stocké le modèle sous logit. Pour extraire les critères AIC, vous utilisez:

logit$aic

Production:

## [1] 27086.65

Étape 7) Évaluer les performances du modèle

Matrice de confusion

La matrice de confusion est un meilleur choix pour évaluer les performances de classification par rapport aux différentes métriques que vous avez vues auparavant. L'idée générale est de compter le nombre de fois où les instances True sont classées comme False.

Pour calculer la matrice de confusion, vous devez d'abord disposer d'un ensemble de prédictions afin qu'elles puissent être comparées aux cibles réelles.

predict <- predict(logit, data_test, type = 'response')# confusion matrixtable_mat <- table(data_test$income, predict > 0.5)table_mat

Explication du code

  • prédire (logit, data_test, type = 'response'): calcule la prédiction sur l'ensemble de test. Définissez type = 'response' pour calculer la probabilité de réponse.
  • table (data_test $ revenu, prédire> 0,5): Calcule la matrice de confusion. prédire> 0,5 signifie qu'il renvoie 1 si les probabilités prédites sont supérieures à 0,5, sinon 0.

Production:

#### FALSE TRUE## <=50K 6310 495## >50K 1074 1229

Chaque ligne d'une matrice de confusion représente une cible réelle, tandis que chaque colonne représente une cible prédite. La première ligne de cette matrice considère le revenu inférieur à 50k (la classe Faux): 6241 ont été correctement classés comme des personnes ayant un revenu inférieur à 50k ( Vrai négatif ), tandis que le reste a été incorrectement classé comme supérieur à 50k ( Faux positif ). La deuxième ligne considère le revenu supérieur à 50k, la classe positive était de 1229 ( vrai positif ), tandis que le vrai négatif était de 1074.

Vous pouvez calculer la précision du modèle en additionnant le vrai positif + le vrai négatif sur l'observation totale

accuracy_Test <- sum(diag(table_mat)) / sum(table_mat)accuracy_Test

Explication du code

  • sum (diag (table_mat)): Somme de la diagonale
  • sum (table_mat): somme de la matrice.

Production:

## [1] 0.8277339 

Le modèle semble souffrir d'un problème, il surestime le nombre de faux négatifs. C'est ce qu'on appelle le paradoxe du test de précision . Nous avons déclaré que l'exactitude est le rapport entre les prévisions correctes et le nombre total de cas. On peut avoir une précision relativement élevée mais un modèle inutile. Cela arrive quand il y a une classe dominante. Si vous regardez en arrière la matrice de confusion, vous pouvez voir que la plupart des cas sont classés comme vrais négatifs. Imaginez maintenant, le modèle a classé toutes les classes comme négatives (c'est-à-dire inférieures à 50k). Vous auriez une précision de 75 pour cent (6718/6718 + 2257). Votre modèle fonctionne mieux mais a du mal à distinguer le vrai positif du vrai négatif.

Dans une telle situation, il est préférable d'avoir une métrique plus concise. On peut regarder:

  • Précision = TP / (TP + FP)
  • Rappel = TP / (TP + FN)

Précision vs rappel

La précision examine l'exactitude de la prédiction positive. Le rappel est le rapport des instances positives qui sont correctement détectées par le classifieur;

Vous pouvez construire deux fonctions pour calculer ces deux métriques

  1. Construire la précision
precision <- function(matrix) {# True positivetp <- matrix[2, 2]# false positivefp <- matrix[1, 2]return (tp / (tp + fp))}

Explication du code

  • mat [1,1]: Renvoie la première cellule de la première colonne de la trame de données, c'est-à-dire le vrai positif
  • mat [1,2]; Renvoie la première cellule de la deuxième colonne de la trame de données, c'est-à-dire le faux positif
recall <- function(matrix) {# true positivetp <- matrix[2, 2]# false positivefn <- matrix[2, 1]return (tp / (tp + fn))}

Explication du code

  • mat [1,1]: Renvoie la première cellule de la première colonne de la trame de données, c'est-à-dire le vrai positif
  • mat [2,1]; Renvoie la deuxième cellule de la première colonne de la trame de données, c'est-à-dire le faux négatif

Vous pouvez tester vos fonctions

prec <- precision(table_mat)precrec <- recall(table_mat)rec

Production:

## [1] 0.712877## [2] 0.5336518

Lorsque le modèle dit qu'il s'agit d'un individu au-dessus de 50k, il est correct dans seulement 54% des cas et peut réclamer des individus au-dessus de 50k dans 72% des cas.

Vous pouvez créer le est une moyenne harmonique de ces deux métriques, ce qui signifie qu'il donne plus de poids aux valeurs inférieures.

f1 <- 2 * ((prec * rec) / (prec + rec))f1

Production:

## [1] 0.6103799 

Compromis entre précision et rappel

Il est impossible d'avoir à la fois une haute précision et un rappel élevé.

Si nous augmentons la précision, l'individu correct sera mieux prédit, mais nous en manquerions beaucoup (rappel inférieur). Dans certaines situations, nous préférons une précision plus élevée que le rappel. Il existe une relation concave entre la précision et le rappel.

  • Imaginez, vous devez prédire si un patient a une maladie. Vous voulez être le plus précis possible.
  • Si vous avez besoin de détecter des personnes potentiellement frauduleuses dans la rue grâce à la reconnaissance faciale, il serait préférable d'attraper de nombreuses personnes qualifiées de frauduleuses même si la précision est faible. La police pourra libérer l'individu non frauduleux.

La courbe ROC

La courbe des caractéristiques de fonctionnement du récepteur est un autre outil couramment utilisé avec la classification binaire. Elle est très similaire à la courbe précision / rappel, mais au lieu de tracer la précision par rapport au rappel, la courbe ROC montre le taux de vrais positifs (c'est-à-dire le rappel) par rapport au taux de faux positifs. Le taux de faux positifs est le rapport des instances négatives qui sont incorrectement classées comme positives. Il est égal à un moins le vrai taux négatif. Le vrai taux négatif est également appelé spécificité . Par conséquent, la courbe ROC trace la sensibilité (rappel) par rapport à la spécificité 1

Pour tracer la courbe ROC, nous devons installer une bibliothèque appelée RORC. On peut trouver dans la bibliothèque conda. Vous pouvez taper le code:

conda install -cr r-rocr --oui

Nous pouvons tracer le ROC avec les fonctions prediction () et performance ().

library(ROCR)ROCRpred <- prediction(predict, data_test$income)ROCRperf <- performance(ROCRpred, 'tpr', 'fpr')plot(ROCRperf, colorize = TRUE, text.adj = c(-0.2, 1.7))

Explication du code

  • prediction (predict, data_test $ Income): la bibliothèque ROCR doit créer un objet de prédiction pour transformer les données d'entrée
  • performance (ROCRpred, 'tpr', 'fpr'): retourne les deux combinaisons à produire dans le graphe. Ici, tpr et fpr sont construits. Tot tracer la précision et le rappel ensemble, utilisez "prec", "rec".

Production:

Étape 8) Améliorez le modèle

Vous pouvez essayer d'ajouter une non-linéarité au modèle avec l'interaction entre

  • âge et heures par semaine
  • sexe et heures par semaine.

Vous devez utiliser le test de score pour comparer les deux modèles

formula_2 <- income~age: hours.per.week + gender: hours.per.week + .logit_2 <- glm(formula_2, data = data_train, family = 'binomial')predict_2 <- predict(logit_2, data_test, type = 'response')table_mat_2 <- table(data_test$income, predict_2 > 0.5)precision_2 <- precision(table_mat_2)recall_2 <- recall(table_mat_2)f1_2 <- 2 * ((precision_2 * recall_2) / (precision_2 + recall_2))f1_2

Production:

## [1] 0.6109181 

Le score est légèrement supérieur au précédent. Vous pouvez continuer à travailler sur les données pour essayer de battre le score.

Résumé

Nous pouvons résumer la fonction pour entraîner une régression logistique dans le tableau ci-dessous:

Paquet

Objectif

fonction

argument

-

Créer un ensemble de données train / test

create_train_set ()

données, taille, train

glm

Entraîner un modèle linéaire généralisé

glm ()

formule, données, famille *

glm

Résumer le modèle

résumé()

modèle ajusté

base

Faire des prédictions

prédire()

modèle ajusté, ensemble de données, type = 'réponse'

base

Créer une matrice de confusion

table()

y, prédire ()

base

Créer un score de précision

somme (diag (table ()) / somme (table ()

ROCR

Créer ROC: Étape 1 Créer une prédiction

prédiction()

prédire (), y

ROCR

Créer ROC: Étape 2 Créer des performances

performance()

prédiction (), 'tpr', 'fpr'

ROCR

Créer un ROC: Étape 3 Tracer un graphique

terrain()

performance()

Les autres types de modèles GLM sont:

- binôme: (link = "logit")

- gaussien: (lien = "identité")

- Gamma: (lien = "inverse")

- inverse.gaussian: (lien = "1 / mu 2")

- poisson: (lien = "log")

- quasi: (lien = "identité", variance = "constante")

- quasibinomial: (lien = "logit")

- quasipoisson: (lien = "log")