
MutualInformation

MutualInformation(dataset, column_list, filter_condition_dict=None, duckdb_connection=None, decimal_places=6)

Calculate Mutual Information (MI) between pairs of columns using DuckDB.

Computes the mutual information metric between specified column pairs to measure their statistical dependence. Supports both categorical and continuous variables with optional filtering conditions.

Details

The function calculates mutual information using the formula: MI(X,Y) = Σ P(x,y) * log(P(x,y)/(P(x)*P(y)))

Where:

- P(x,y) is the joint probability
- P(x) and P(y) are the marginal probabilities
- The sum runs over all possible value pairs

Higher MI values indicate stronger relationships between variables:

- MI = 0: the variables are independent
- MI > 0: the variables share information
- Larger values suggest stronger dependence

Example

Consider salary and department columns:

Salary: [50000, 60000, 50000, 75000, 55000]
Dept:   ['IT', 'HR', 'IT', 'Fin', 'IT']

Example calculation:

1. Calculate the joint probabilities P(salary, dept)
2. Calculate the marginal probabilities P(salary) and P(dept)
3. Compute MI using the formula above
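The three steps above can be worked through by hand in plain Python. This sketch assumes the natural logarithm; the formula does not pin down a log base, so the value need not match the library's output.

```python
from collections import Counter
from math import log

salary = [50000, 60000, 50000, 75000, 55000]
dept = ['IT', 'HR', 'IT', 'Fin', 'IT']

n = len(salary)
joint = Counter(zip(salary, dept))   # counts of each (salary, dept) pair
salary_counts = Counter(salary)      # marginal counts for salary
dept_counts = Counter(dept)          # marginal counts for dept

# MI(X,Y) = sum over pairs of P(x,y) * log(P(x,y) / (P(x) * P(y)))
mi = sum(
    (c / n) * log((c / n) / ((salary_counts[x] / n) * (dept_counts[y] / n)))
    for (x, y), c in joint.items()
)
print(round(mi, 6))  # 0.950271
```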

Code example:

>>> df = pl.DataFrame({
...     'salary': [50000, 60000, 50000, 75000, 55000],
...     'department': ['IT', 'HR', 'IT', 'Finance', 'IT'],
...     'age': [25, 30, 25, 35, 28]
... })
>>> result = MutualInformation(df, [{'salary': 'department'}])
>>> print(result)
[{'columns': 'salary,department',
  'mutual_information': 0.682345,
  'table_name': 'mutual_info_abc123',
  'execution_timestamp_utc': '2024-01-25 10:30:45',
  'filter_conditions': None}]

Parameters

dataset : Any

Input dataset (DataFrame or table name). Can be:

- Polars DataFrame
- Pandas DataFrame
- PyArrow Table
- String naming an existing table in the DuckDB connection

The dataset must contain all columns specified in column_list.

column_list : List[Dict[str, str]]

List of column pairs to analyze. Each pair is a single-item dictionary whose key and value are the two column names. Example: [{'salary': 'department'}, {'age': 'experience'}]. Both categorical and numeric columns are supported. The columns must exist in the dataset and contain non-null values.
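Each single-item dictionary unpacks to an ordered (col1, col2) pair. A minimal sketch of that validation and unpacking, mirroring the checks in the source listing below:

```python
column_list = [{'salary': 'department'}, {'age': 'experience'}]

processed = []
for idx, pair in enumerate(column_list):
    # Each item must be a dictionary with exactly one key-value pair.
    if not isinstance(pair, dict) or len(pair) != 1:
        raise ValueError(f"Item at index {idx} must be a single-pair dict")
    col1, col2 = next(iter(pair.items()))
    processed.append((col1, col2))

print(processed)  # [('salary', 'department'), ('age', 'experience')]
```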

filter_condition_dict : Optional[Dict[str, Union[str, int, float]]]

Row filter conditions applied before the MI calculation. Example: {'department': 'IT', 'age': 25}. Keys must be valid column names, and values must match the column data types. Default: None (no filtering)
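Internally, the conditions are ANDed together into a single SQL WHERE clause, with string values quoted and numeric values inlined, as in this sketch:

```python
filter_condition_dict = {'department': 'IT', 'age': 25}

# Quote strings, inline numerics; multiple conditions are ANDed together.
where_clause = "WHERE " + " AND ".join(
    f"{col} = '{val}'" if isinstance(val, str) else f"{col} = {val}"
    for col, val in filter_condition_dict.items()
)
print(where_clause)  # WHERE department = 'IT' AND age = 25
```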

duckdb_connection : Optional[DuckDBPyConnection]

Existing DuckDB connection to use. If None, a temporary connection is created and closed internally. If dataset is a table name, the connection must have access to that table. Default: None

decimal_places : int

Number of decimal places to which MI values are rounded. Must be a non-negative integer. Default: 6

Returns

List[Dict[str, Union[str, float, dict, None]]]

Analysis results for each column pair:

- columns : str
  Comma-separated column pair names (e.g., "salary,department")
- mutual_information : float
  Calculated MI value rounded to the specified number of decimal places
- table_name : str
  Name of the analyzed table
- execution_timestamp_utc : str
  UTC timestamp of execution
- filter_conditions : Optional[Dict]
  Applied filter conditions, or None if no filter was given
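The return value is a plain list of dictionaries, so it can be post-processed directly. A small usage sketch with hand-written sample values (not real output) that ranks pairs by MI:

```python
# Hand-written sample results in the documented return shape.
results = [
    {'columns': 'salary,department', 'mutual_information': 0.682345,
     'table_name': 'mutual_info_abc123',
     'execution_timestamp_utc': '2024-01-25 10:30:45',
     'filter_conditions': None},
    {'columns': 'age,department', 'mutual_information': 0.412000,
     'table_name': 'mutual_info_abc123',
     'execution_timestamp_utc': '2024-01-25 10:30:45',
     'filter_conditions': None},
]

# Rank pairs from strongest to weakest dependence.
ranked = sorted(results, key=lambda r: r['mutual_information'], reverse=True)
strongest = ranked[0]['columns']
print(strongest)  # salary,department
```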

Raises

ValueError

- Empty or invalid column_list format
- Column not found in dataset
- Negative decimal_places
- Invalid filter column names
- Type mismatch in filter conditions

Source code in src/whistlingduck/analyzers/MutualInformation.py
def MutualInformation(dataset: Any,
                     column_list: List[Dict[str, str]],
                     filter_condition_dict: Optional[Dict[str, Union[str, int, float]]] = None,
                     duckdb_connection: Optional[DuckDBPyConnection] = None,
                     decimal_places: int = 6
                    ) -> List[Dict[str, Union[str, float, dict, None]]]:
    """
    Calculate Mutual Information (MI) between pairs of columns using DuckDB.

    Computes the mutual information metric between specified column pairs to measure
    their statistical dependence. Supports both categorical and continuous variables
    with optional filtering conditions.

    Details
    -------
    The function calculates mutual information using the formula:
        MI(X,Y) = Σ P(x,y) * log(P(x,y)/(P(x)*P(y)))

    Where:
    - P(x,y) is the joint probability
    - P(x) and P(y) are marginal probabilities
    - The sum is over all possible value pairs

    Higher MI values indicate stronger relationships between variables:
    - MI = 0: Variables are independent
    - MI > 0: Variables share information
    - Higher values suggest stronger dependencies

    Example
    -------
    Consider salary and department columns:
    Salary: [50000, 60000, 50000, 75000, 55000]
    Dept:   ['IT',  'HR',   'IT',   'Fin',  'IT']

    Example calculation:
    1. Calculate joint probabilities P(salary,dept)
    2. Calculate marginal probabilities P(salary) and P(dept)
    3. Compute MI using the formula above

    Code example:
    >>> df = pl.DataFrame({
    ...     'salary': [50000, 60000, 50000, 75000, 55000],
    ...     'department': ['IT', 'HR', 'IT', 'Finance', 'IT'],
    ...     'age': [25, 30, 25, 35, 28]
    ... })
    >>> result = MutualInformation(df, [{'salary': 'department'}])
    >>> print(result)
    [{'columns': 'salary,department', 
      'mutual_information': 0.682345,
      'table_name': 'mutual_info_abc123',
      'execution_timestamp_utc': '2024-01-25 10:30:45',
      'filter_conditions': None}]

    Parameters
    ----------
    dataset : Any
        Input dataset (DataFrame or table name). Can be:
        - Polars DataFrame
        - Pandas DataFrame
        - PyArrow Table
        - String representing existing table name in DuckDB connection
        The dataset must contain all columns specified in column_list.

    column_list : List[Dict[str, str]]
        List of column pairs to analyze. Each pair is a single-item dictionary
        where key and value are column names.
        Example: [{'salary': 'department'}, {'age': 'experience'}]
        Both categorical and numeric columns are supported.
        Columns must exist in dataset and contain non-null values.

    filter_condition_dict : Optional[Dict[str, Union[str, int, float]]]
        Row filter conditions to apply before MI calculation.
        Example: {'department': 'IT', 'age': 25}
        Keys must be valid column names.
        Values must match column data types.
        Default: None (no filtering)

    duckdb_connection : Optional[DuckDBPyConnection]
        Existing DuckDB connection to use.
        If None, creates temporary connection.
        Connection must have access to dataset if table name provided.
        Default: None

    decimal_places : int
        Number of decimal places for MI values.
        Must be non-negative integer.
        Affects precision of returned MI values.
        Default: 6

    Returns
    -------
    List[Dict[str, Union[str, float, dict, None]]]
        Analysis results for each column pair:
        - columns : str
            Comma-separated column pair names (e.g., "salary,department")
        - mutual_information : float
            Calculated MI value rounded to specified decimal places
        - table_name : str
            Name of analyzed table
        - execution_timestamp_utc : str
            UTC timestamp of execution
        - filter_conditions : Optional[Dict]
            Applied filter conditions if any, else None

    Raises
    ------
    ValueError
        - Empty or invalid column_list format
        - Column not found in dataset
        - Invalid decimal_places (negative)
        - Invalid filter column names
        - Type mismatch in filter conditions
    """
    # Validate decimal_places
    if decimal_places < 0:
        raise ValueError(
            "decimal_places must be non-negative. "
            "Please provide a value >= 0."
        )

    # Validate column_list input
    if not isinstance(column_list, list) or not column_list:
        raise ValueError(
            "column_list must be a non-empty list of dictionaries. "
            "Example: [{'column1': 'column2'}]"
        )

    # Process and validate column pairs
    processed_pairs = []
    for idx, pair in enumerate(column_list):
        if not isinstance(pair, dict) or len(pair) != 1:
            raise ValueError(
                f"Item at index {idx} must be a dictionary with one key-value pair. "
                "Example: {'column1': 'column2'}"
            )
        col1, col2 = next(iter(pair.items()))
        processed_pairs.append({'col1': col1, 'col2': col2})

    # Generate UUID for table name and get UTC timestamp
    unique_id = str(uuid.uuid4()).replace('-', '_')
    timestamp = datetime.now(timezone.utc)
    temp_table_name = f"mutual_info_{unique_id}"

    # Handle DuckDB connection and table registration
    if duckdb_connection is None:
        con = duckdb.connect()
        try:
            con.register(temp_table_name, dataset)
            source_table = temp_table_name
        except Exception as e:
            con.close()
            raise ValueError(
                f"Failed to register dataset: {str(e)}. "
                "Please ensure the dataset is in a DuckDB-compatible format."
            )
    else:
        con = duckdb_connection
        if isinstance(dataset, str):
            try:
                con.sql(f"PRAGMA table_info('{dataset}')")
                source_table = dataset
            except duckdb.CatalogException:
                raise ValueError(f"Table '{dataset}' does not exist in the DuckDB connection")
        else:
            try:
                con.register(temp_table_name, dataset)
                source_table = temp_table_name
            except Exception as e:
                raise ValueError(
                    f"Failed to register dataset with existing connection: {str(e)}. "
                    "Please ensure the dataset is in a DuckDB-compatible format."
                )

    # Get table info for column validation
    dtype_info = con.sql(f"PRAGMA table_info('{source_table}')").pl()
    dataset_columns = dtype_info['name'].to_list()

    # Validate columns existence
    for pair in processed_pairs:
        invalid_cols = []
        for col in [pair['col1'], pair['col2']]:
            if col not in dataset_columns:
                invalid_cols.append(col)
        if invalid_cols:
            if duckdb_connection is None:
                con.close()
            raise ValueError(
                f"These columns were not found in the dataset: {', '.join(invalid_cols)}. "
                "Please verify the column names."
            )

    # Handle filter conditions
    if filter_condition_dict:
        if not isinstance(filter_condition_dict, dict):
            if duckdb_connection is None:
                con.close()
            raise ValueError(
                "filter_condition_dict must be a dictionary. "
                "For single filter condition, use {'column_name': value}."
            )

        invalid_filter_cols = list(set(filter_condition_dict.keys()) - set(dataset_columns))
        if invalid_filter_cols:
            if duckdb_connection is None:
                con.close()
            raise ValueError(
                f"We couldn't find these columns in your dataset: {', '.join(invalid_filter_cols)}. "
                "Please verify the column names in your filter conditions."
            )

        where_clause = "WHERE " + " AND ".join(
            f"{col} = '{val}'" if isinstance(val, str) else f"{col} = {val}"
            for col, val in filter_condition_dict.items()
        )
    else:
        where_clause = ""

    # Generate CTEs for all pairs
    cte_parts = []
    select_parts = []

    for idx, pair in enumerate(processed_pairs):
        col1, col2 = pair['col1'], pair['col2']
        cte_name = f"base_{idx}"
        prob_name = f"prob_{idx}"

        cte_parts.extend([
            f"""
            {cte_name} AS (
                SELECT {col1}, {col2},
                       COUNT(*)::FLOAT as joint_count,
                       SUM(COUNT(*)) OVER()::FLOAT as total_count
                FROM {source_table}
                {where_clause}
                GROUP BY {col1}, {col2}
            ),
            {prob_name} AS (
                SELECT 
                    joint_count / total_count as joint_prob,
                    SUM(joint_count) OVER(PARTITION BY {col1}) / total_count as prob_a,
                    SUM(joint_count) OVER(PARTITION BY {col2}) / total_count as prob_b
                FROM {cte_name}
            )"""
        ])

        select_parts.append(
            f"""
            SELECT 
                '{col1},{col2}' as columns,
                ROUND(
                    SUM(
                        CASE 
                            WHEN joint_prob > 0 THEN 
                                joint_prob * LOG(joint_prob / NULLIF(prob_a * prob_b, 0))
                            ELSE 0 
                        END
                    ), {decimal_places}
                ) as mutual_information
            FROM {prob_name}"""
        )

    # Combine all CTEs and SELECT statements
    final_query = f"""
    WITH {', '.join(cte_parts)}
    {' UNION ALL '.join(select_parts)}
    """

    # Execute query and get results
    try:
        result = con.sql(final_query).pl()
    except Exception as e:
        if duckdb_connection is None:
            con.close()
        raise ValueError(f"Error executing query: {str(e)}")

    # Close connection if created internally
    if duckdb_connection is None:
        con.close()

    # Process results
    results = result.to_dicts()
    for row in results:
        row.update({
            'table_name': source_table,
            'execution_timestamp_utc': timestamp.strftime("%Y-%m-%d %H:%M:%S"),
            'filter_conditions': filter_condition_dict if filter_condition_dict else None
        })

    return results