Source code for linreg_ally.multicollinearity

# Author: Alex Wong
# multicollinearity.py
# 01/10/2025

import altair as alt
import pandas as pd
import numpy as np
from statsmodels.stats.outliers_influence import variance_inflation_factor 

def check_multicollinearity(train_df: pd.DataFrame, threshold=None, vif_only: bool = False):
    """
    Detect multicollinearity among the numeric features of a training dataset.

    Computes the variance inflation factor ('VIF') of every numeric column in
    ``train_df`` and, unless ``vif_only`` is set, a chart of the absolute
    pairwise Pearson correlations between those columns.

    Parameters
    ----------
    train_df : pd.DataFrame
        Training dataset. Only columns selected by
        ``select_dtypes(include='number')`` are analysed.
    threshold : int or float, optional
        Minimum VIF for a feature to be included in the returned dataframe.
        Default is None, which keeps every feature.
    vif_only : bool
        If True, only the dataframe containing the VIF scores is returned and
        the correlation chart is not built. Otherwise, the correlation chart
        is also returned. Default is False.

    Returns
    -------
    pd.DataFrame
        A dataframe with the columns 'Features' and 'VIF', one row per numeric
        column of ``train_df`` (restricted to rows with ``VIF >= threshold``
        when ``threshold`` is given).
    alt.Chart
        A chart that shows the absolute pairwise Pearson correlations of all
        numeric columns in ``train_df``. Only returned when ``vif_only`` is
        False, as the second element of a ``(vif_df, corr_chart)`` tuple.

    Raises
    ------
    TypeError
        If `train_df` is not a pandas DataFrame.

    Notes
    -----
    The numeric columns are passed to ``variance_inflation_factor`` exactly as
    given; this function does not add a constant/intercept column.

    Examples
    --------
    >>> from linreg_ally.multicollinearity import check_multicollinearity
    >>> vif_df, corr_chart = check_multicollinearity(train_df)
    >>> vif_df = check_multicollinearity(train_df, threshold = 5, vif_only = True)
    """
    if not isinstance(train_df, pd.DataFrame):
        raise TypeError(f"Expect train_df to be a pd.Dataframe but got {type(train_df)}")

    # Select only numeric columns in train_df
    numeric_df = train_df.select_dtypes(include='number')

    # Calculate VIF for each feature; the statsmodels helper addresses columns
    # by position, hence the iteration over column indices.
    vif_scores = [
        variance_inflation_factor(numeric_df, col_idx)
        for col_idx in range(numeric_df.shape[1])
    ]
    vif_df = pd.DataFrame({'Features': numeric_df.columns, 'VIF': vif_scores})

    # Filter VIF dataframe by the threshold if provided
    if threshold is not None:
        vif_df = vif_df[vif_df['VIF'] >= threshold]

    # Return early so the correlation table and chart are only built when the
    # caller actually asked for them.
    if vif_only:
        return vif_df

    # Calculate pairwise Pearson correlations
    corr_df = (
        numeric_df
        .corr('pearson', numeric_only=True)
        .abs()  # Use abs for negative correlation to stand out
        # Name both axes explicitly so the long-format columns are always
        # 'level_0' / 'level_1', even when train_df.columns carries a name.
        .rename_axis(index='level_0', columns='level_1')
        .stack()  # Get df into long format for altair
        .reset_index(name='corr')
    )

    # Round the correlation values to 3 decimal places
    corr_df['corr'] = corr_df['corr'].round(3)

    # Create the correlation chart using Altair
    corr_chart = alt.Chart(corr_df).mark_circle().encode(
        x='level_0',
        y='level_1',
        size='corr',
        color='corr',
        tooltip=['level_0', 'level_1', 'corr'],
    )

    return vif_df, corr_chart