earth
packageFunction | Works |
---|---|
tidypredict_fit() , tidypredict_sql() ,
parse_model() |
✔ |
tidypredict_to_column() |
✔ |
tidypredict_test() |
✔ |
tidypredict_interval() ,
tidypredict_sql_interval() |
✗ |
parsnip |
✔ |
tidypredict_
functionslibrary(earth)
data("etitanic", package = "earth")
model <- earth(age ~ sibsp + parch, data = etitanic, degree = 3)
Create the R formula
tidypredict_fit(model)
#> 22.2918960405404 + (ifelse(parch < 2, 2 - parch, 0) * 4.85356462114363) +
#> (ifelse(parch > 2, parch - 2, 0) * 13.0493891423278) + (ifelse(parch >
#> 4, parch - 4, 0) * -18.8998708031821) + (ifelse(sibsp > 1,
#> sibsp - 1, 0) * -7.71566779782019) + (ifelse(sibsp > 1, sibsp -
#> 1, 0) * ifelse(parch < 1, 1 - parch, 0) * 7.40395975552276) +
#> (ifelse(sibsp > 1, sibsp - 1, 0) * ifelse(parch > 1, parch -
#> 1, 0) * 4.41874354843216)
SQL output example
tidypredict_sql(model, dbplyr::simulate_odbc())
#> <SQL> (((((22.2918960405404 + (CASE WHEN (`parch` < 2.0) THEN (2.0 - `parch`) WHEN NOT (`parch` < 2.0) THEN 0.0 END * 4.85356462114363)) + (CASE WHEN (`parch` > 2.0) THEN (`parch` - 2.0) WHEN NOT (`parch` > 2.0) THEN 0.0 END * 13.0493891423278)) + (CASE WHEN (`parch` > 4.0) THEN (`parch` - 4.0) WHEN NOT (`parch` > 4.0) THEN 0.0 END * -18.8998708031821)) + (CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT (`sibsp` > 1.0) THEN 0.0 END * -7.71566779782019)) + ((CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT (`sibsp` > 1.0) THEN 0.0 END * CASE WHEN (`parch` < 1.0) THEN (1.0 - `parch`) WHEN NOT (`parch` < 1.0) THEN 0.0 END) * 7.40395975552276)) + ((CASE WHEN (`sibsp` > 1.0) THEN (`sibsp` - 1.0) WHEN NOT (`sibsp` > 1.0) THEN 0.0 END * CASE WHEN (`parch` > 1.0) THEN (`parch` - 1.0) WHEN NOT (`parch` > 1.0) THEN 0.0 END) * 4.41874354843216)
Add the prediction to the original table
library(dplyr)
etitanic %>%
tidypredict_to_column(model) %>%
glimpse()
#> Rows: 1,046
#> Columns: 7
#> $ pclass <fct> 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1st, 1…
#> $ survived <int> 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1…
#> $ sex <fct> female, male, female, male, female, male, female, male, femal…
#> $ age <dbl> 29.0000, 0.9167, 2.0000, 30.0000, 25.0000, 48.0000, 63.0000, …
#> $ sibsp <int> 0, 1, 1, 1, 1, 0, 1, 0, 2, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1…
#> $ parch <int> 0, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1…
#> $ fit <dbl> 31.99903, 22.29190, 22.29190, 22.29190, 22.29190, 31.99903, 3…
Confirm that tidypredict
results match to the
model’s predict()
results
tidypredict
supports the glm
argument as
well:
model <- earth(survived ~ .,
data = etitanic,
glm = list(family = binomial), degree = 2
)
tidypredict_fit(model)
#> 1 - 1/(1 + exp(2.91352600741336 + (ifelse(age > 32, age - 32,
#> 0) * -0.0375714917713104) + (ifelse(pclass == "2nd", 1, 0) *
#> ifelse(sex == "male", 1, 0) * -1.76809447811121) + (ifelse(pclass ==
#> "3rd", 1, 0) * -5.03005595780694) + (ifelse(pclass == "3rd",
#> 1, 0) * ifelse(sibsp < 4, 4 - sibsp, 0) * 0.618652747659846) +
#> (ifelse(pclass == "3rd", 1, 0) * ifelse(sex == "male", 1,
#> 0) * 1.22269536265148) + (ifelse(sex == "male", 1, 0) *
#> -3.18562450248531) + (ifelse(sex == "male", 1, 0) * ifelse(age <
#> 16, 16 - age, 0) * 0.241814028713263)))
The spec sets the is_glm
entry to 1, as well as the
family
and link
entries.
str(parse_model(model), 2)
#> List of 2
#> $ general:List of 6
#> ..$ model : chr "earth"
#> ..$ type : chr "regression"
#> ..$ version: num 2
#> ..$ is_glm : num 1
#> ..$ family : chr "binomial"
#> ..$ link : chr "logit"
#> $ terms :List of 8
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_regression" "list"
parsnip
fitted models are also supported by
tidypredict
:
library(parsnip)
p_model <- mars(mode = "regression", prod_degree = 3) %>%
set_engine("earth") %>%
fit(age ~ sibsp + parch, data = etitanic)
tidypredict_fit(p_model)
#> 22.2918960405404 + (ifelse(parch < 2, 2 - parch, 0) * 4.85356462114363) +
#> (ifelse(parch > 2, parch - 2, 0) * 13.0493891423278) + (ifelse(parch >
#> 4, parch - 4, 0) * -18.8998708031821) + (ifelse(sibsp > 1,
#> sibsp - 1, 0) * -7.71566779782019) + (ifelse(sibsp > 1, sibsp -
#> 1, 0) * ifelse(parch < 1, 1 - parch, 0) * 7.40395975552276) +
#> (ifelse(sibsp > 1, sibsp - 1, 0) * ifelse(parch > 1, parch -
#> 1, 0) * 4.41874354843216)
Here is an example of the model spec:
pm <- parse_model(model)
str(pm, 2)
#> List of 2
#> $ general:List of 6
#> ..$ model : chr "earth"
#> ..$ type : chr "regression"
#> ..$ version: num 2
#> ..$ is_glm : num 1
#> ..$ family : chr "binomial"
#> ..$ link : chr "logit"
#> $ terms :List of 8
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> ..$ :List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_regression" "list"
str(pm$terms[1:2])
#> List of 2
#> $ :List of 4
#> ..$ label : chr "(Intercept)"
#> ..$ coef : num 2.91
#> ..$ is_intercept: num 1
#> ..$ fields :List of 1
#> .. ..$ : list()
#> $ :List of 4
#> ..$ label : chr "h(age-32)"
#> ..$ coef : num -0.0376
#> ..$ is_intercept: num 0
#> ..$ fields :List of 1
#> .. ..$ :List of 4
#> .. .. ..$ type: chr "operation"
#> .. .. ..$ col : chr "age"
#> .. .. ..$ val : num 32
#> .. .. ..$ op : chr "morethan"