# Source code for dit.divergences.copy_mutual_information
"""
The copy mutual information, as defined by Kolchinsky & Corominas-Murtra.
"""
from .pmf import relative_entropy
from ..utils import unitful
__all__ = [
'copy_mutual_information',
]
def binary_kullback_leibler_divergence(p, q):
    """
    Compute the binary Kullback-Leibler divergence.

    This is the relative entropy between the two-outcome distributions
    [p, 1 - p] and [q, 1 - q], delegated to ``relative_entropy``.

    Parameters
    ----------
    p : float
        The first probability.
    q : float
        The second probability.

    Returns
    -------
    dkl : float
        The binary Kullback-Leibler divergence.
    """
    first_pmf = [p, 1 - p]
    second_pmf = [q, 1 - q]
    return relative_entropy(first_pmf, second_pmf)
@unitful
def specific_copy_mutual_information(p_Y_g_x, p_Y, x):
    """
    Compute the specific copy mutual information of the event ``x``.

    Roughly, it is the portion of the specific mutual information which
    results from X = Y = x. It is nonzero only when p(Y=x|X=x) exceeds
    p(Y=x); in that case it is the binary Kullback-Leibler divergence
    between those two probabilities.

    Parameters
    ----------
    p_Y_g_x : Distribution
        The probability p(Y|X=x).
    p_Y : Distribution
        The probability p(Y).
    x : event
        An event in the sample space of X, Y.

    Returns
    -------
    Icopy : float
        The specific copy mutual information of x; ``0`` when
        p(Y=x|X=x) is not greater than p(Y=x).
    """
    prior = p_Y[x]
    posterior = p_Y_g_x[x]
    # Only an increase in the probability of Y = x counts as copying.
    if not posterior > prior:
        return 0
    return binary_kullback_leibler_divergence(posterior, prior)
def copy_mutual_information(dist, X, Y, rv_mode=None):
    """
    Computes the copy mutual information. Roughly, it is the
    portion of the mutual information which results from X = Y.

    It is the sum, over outcomes x of the marginal of X, of p(X=x) times
    the specific copy mutual information of x.

    Parameters
    ----------
    dist : Distribution
        The distribution of interest.
    X : iterable
        The indices to consider as X.
    Y : iterable
        The indices to consider as Y.
    rv_mode : str, None
        Specifies how to interpret ``X`` and ``Y``. Valid options are:
        {'indices', 'names'}. If equal to 'indices', then the elements
        of ``X`` and ``Y`` are interpreted as random variable indices.
        If equal to 'names', then the elements are interpreted as random
        variable names. If ``None``, then the value of ``dist._rv_mode``
        is consulted, which defaults to 'indices'.

    Returns
    -------
    Icopy : float
        The copy mutual information between X and Y.

    Notes
    -----
    Each outcome x of X is used to look up the event Y = x, so X and Y
    are presumably expected to share an alphabet -- TODO confirm callers
    guarantee this.
    """
    # p(Y): the reference distribution for every specific term.
    p_Y = dist.marginal(Y, rv_mode=rv_mode)
    # p(X), plus one conditional p(Y | X=x) per outcome x of p(X).
    marg, cdists = dist.condition_on(X, rvs=Y, rv_mode=rv_mode)
    # Generator expression: the weighted terms need not be materialised.
    return sum(
        marg[x] * specific_copy_mutual_information(cdist, p_Y, x)
        for x, cdist in zip(marg.outcomes, cdists)
    )