# Source code for dit.rate_distortion.curves

"""
Objects to compute single rate-distortion curves.
"""

from __future__ import division

import numpy as np

from .blahut_arimoto import blahut_arimoto, blahut_arimoto_ib
from .distortions import hamming
from .information_bottleneck import InformationBottleneck, InformationBottleneckDivergence
from .. import Distribution
from ..algorithms.minimal_sufficient_statistic import mss
from ..exceptions import ditException
from ..multivariate import entropy, total_correlation
from ..utils import flatten

class RDCurve(object):
    """
    Compute a rate-distortion curve.

    The curve is computed once, at construction time, by sweeping a grid of
    beta values. The results are stored in the attributes `rates`,
    `distortions`, `ranks` and `alphabets`, each of which is aligned with
    `betas`.
    """

    def __init__(self, dist, rv=None, crvs=None, beta_min=0, beta_max=10,
                 beta_num=101, alpha=1.0, distortion=hamming, method=None):
        """
        Initialize the curve computer.

        Parameters
        ----------
        dist : Distribution
            The distribution of interest.
        rv : iterable, None
            The random variables to compute the rate-distortion curve of.
            If None, use all.
        crvs : iterable, None
            The random variables to condition on.
        beta_min : float
            The minimum beta value for the curve. Defaults to 0.
        beta_max : float, None
            The maximum beta value for the curve. Defaults to 10. If None,
            iteratively find a beta value with nearly maximal rate.
        beta_num : int
            The number of beta values for the curve. Defaults to 101.
        alpha : float
            The alpha value to utilize. 1.0 corresponds to the standard information
            bottleneck, while 0.0 corresponds to the deterministic bottleneck.
            It is forwarded to the distortion's optimizer and is therefore
            only consulted by the 'sp' method.
        distortion : Distortion
            The distortion to use. Its `name`, `matrix` and `optimizer`
            attributes are read.
        method : {'sp', 'ba', None}
            The method to utilize in computing the curve. If 'sp', utilize
            scipy.optimize; if 'ba' utilize the iterative Blahut-Arimoto
            algorithm. Defaults to None, in which case 'sp' is used if
            distortion supports it, and 'ba' if not.

        Raises
        ------
        ditException
            Raised if any of the parameters are not viable.
        """
        if rv is None:
            rv = list(flatten(dist.rvs))

        self.dist = dist.copy()
        self.rv = rv
        self.crvs = crvs

        d = dist.coalesce([self.rv])
        self.p_x = d.pmf

        self._distortion = distortion

        if method is None:
            if distortion.optimizer:
                method = 'sp'
            elif distortion.matrix:  # pragma: no cover
                method = 'ba'
            else:  # pragma: no cover
                msg = "Distortion measure is vacuous."
                raise ditException(msg)
        elif method not in ('sp', 'ba'):  # pragma: no cover
            msg = "Method '{}' not supported.".format(method)
            raise ditException(msg)
        elif method == 'sp' and not distortion.optimizer:  # pragma: no cover
            msg = "Method is 'sp' but distortion does not have an optimizer."
            raise ditException(msg)
        elif method == 'ba' and not distortion.matrix:  # pragma: no cover
            msg = "Method is 'ba' but distortion does not have a matrix."
            raise ditException(msg)
        elif method == 'ba' and crvs:  # pragma: no cover
            msg = "Method 'ba' does not support conditional variables."
            raise ditException(msg)

        self._get_rd = {'ba': self._get_rd_ba,
                        'sp': self._get_rd_sp,
                        }[method]

        # A distortion that only supplies a matrix (method 'ba') has no
        # optimizer to call; building one unconditionally would fail before
        # any curve could be computed.
        if distortion.optimizer:
            self._rd_opt = distortion.optimizer(self.dist,
                                                beta=0.0,
                                                alpha=alpha,
                                                rv=self.rv,
                                                crvs=self.crvs)
        else:  # pragma: no cover
            self._rd_opt = None

        self._max_rate = entropy(d)
        # The distortion found at beta = 0 is recorded as the maximum.
        _, self._max_distortion, _, _ = self._get_rd(beta=0.0)
        self._max_rank = len(d.outcomes)

        if beta_max is None:
            beta_max = self.find_max_beta()
        self.betas = np.linspace(beta_min, beta_max, beta_num)

        try:  # pragma: no cover
            dist_name = [dist.name]
        except AttributeError:
            dist_name = []
        self.label = " ".join(dist_name + [self._distortion.name])

        self.compute()

    def __add__(self, other):  # pragma: no cover
        """
        Combine two RDCurves into an RDPlotter.

        Parameters
        ----------
        other : RDCurve
            The curve to aggregate with self.

        Returns
        -------
        plotter : RDPlotter
            A plotter with both self and other. `NotImplemented` is returned
            if `other` is not an RDCurve.
        """
        # Imported lazily so that plotting is only loaded when requested.
        from .plotting import RDPlotter
        if isinstance(other, RDCurve):
            plotter = RDPlotter(self, other)
            return plotter
        else:
            return NotImplemented

    def find_max_beta(self):
        """
        Find a beta value at which the rate is nearly maximal.

        Beta is grown geometrically (1.5, 2.25, ...) until the rate computed
        at that beta is within tolerance of the maximum rate.

        Returns
        -------
        beta_max : float
            The smallest tried beta value whose rate is close to the
            maximum rate.
        """
        beta_max = 1
        rate = 0

        # NOTE: there is no iteration cap; the loop only ends once the
        # computed rate reaches `self._max_rate` within tolerance.
        while not np.isclose(rate, self._max_rate, atol=1e-5, rtol=1e-5):
            beta_max = 1.5*beta_max
            rate, _, _, _ = self._get_rd(beta=beta_max)

        return beta_max

    def _get_rd_sp(self, beta, initial=None):
        """
        Compute the rate-distortion pair for beta using scipy.optimize.

        Parameters
        ----------
        beta : float
            The beta value to optimize for.
        initial : np.ndarray, None
            An initial optimization vector, useful for numerical continuation.

        Returns
        -------
        r : float
            The rate.
        d : float
            The distortion.
        q : np.ndarray
            The joint constructed by the optimizer, summed over axis 1
            (presumably leaving the matrix p(x, x_hat)).
        x0 : np.ndarray
            The found optima.
        """
        self._rd_opt._beta = beta
        self._rd_opt.optimize(x0=initial)
        # Copy so that later optimizations cannot alter the returned vector.
        x0 = self._rd_opt._optima.copy()
        q = self._rd_opt.construct_joint(self._rd_opt._optima)
        r = self._rd_opt.rate(q)
        d = self._rd_opt.distortion(q)
        return r, d, q.sum(axis=1), x0

    def _get_rd_ba(self, beta, initial=None):
        """
        Compute the rate-distortion pair for beta using Blahut-Arimoto.

        Parameters
        ----------
        beta : float
            The beta value to optimize for.
        initial : np.ndarray, None
            Accepted for interface parity with `_get_rd_sp`; it is not used
            by the computation.

        Returns
        -------
        r : float
            The rate.
        d : float
            The distortion.
        q : np.ndarray
            The matrix p(x, x_hat)
        x0 : np.ndarray, None
            The `initial` argument, passed through unchanged.
        """
        (r, d), q = blahut_arimoto(p_x=self.p_x,
                                   beta=beta,
                                   distortion=self._distortion.matrix,
                                   )
        return r, d, q, initial

    def compute(self):
        """
        Sweep beta and compute the rate-distortion curve.

        The beta values are visited from largest to smallest, and the optima
        found at each beta are handed to the next computation as its initial
        condition. The results are stored, in the same order as `self.betas`,
        in `self.rates`, `self.distortions`, `self.ranks` and
        `self.alphabets`.
        """
        rates = []
        distortions = []
        ranks = []
        alphabets = []

        x0 = None

        for beta in self.betas[::-1]:
            r, d, q, x0 = self._get_rd(beta, initial=x0)
            rates.append(r)
            distortions.append(d)

            column_mass = q.sum(axis=0, keepdims=True)

            # A column with no mass yields 0/0 = NaN; zero those entries so
            # they do not reach the rank computation (IBCurve does the same).
            with np.errstate(divide='ignore', invalid='ignore'):
                q_x_xhat = q / column_mass
            q_x_xhat[np.isnan(q_x_xhat)] = 0

            ranks.append(np.linalg.matrix_rank(q_x_xhat, tol=1e-5))
            # Count the columns which carry non-negligible probability.
            alphabets.append((column_mass > 1e-6).sum())

        # Undo the reversed sweep so results line up with `self.betas`.
        self.rates = np.asarray(rates)[::-1]
        self.distortions = np.asarray(distortions)[::-1]
        self.ranks = np.asarray(ranks)[::-1]
        self.alphabets = np.asarray(alphabets)[::-1]

    def plot(self, downsample=5):  # pragma: no cover
        """
        Construct an RDPlotter and utilize it to plot the rate-distortion
        curve.

        Parameters
        ----------
        downsample : int
            How frequently to display points along the RD curve.

        Returns
        -------
        fig : plt.figure
            The resulting figure.
        """
        from .plotting import RDPlotter
        plotter = RDPlotter(self)
        return plotter.plot(downsample)

class IBCurve(object):
    """
    Compute an information bottleneck curve.

    The curve is computed once, at construction time, by sweeping a grid of
    beta values. The results are stored in the attributes `complexities`,
    `entropies`, `relevances`, `errors`, `ranks`, `alphabets` and
    `distortions`, each of which is aligned with `betas`.
    """

    def __init__(self, dist, rvs=None, crvs=None, rv_mode=None, beta_min=0.0,
                 beta_max=15.0, beta_num=101, alpha=1.0, method='sp',
                 divergence=None):
        """
        Initialize the curve computer.

        Parameters
        ----------
        dist : Distribution
            The distribution of interest.
        rvs : iterable, None
            The two random variables to compute the information bottleneck
            curve of. If None, use [0], [1].
        crvs : iterable, None
            The random variables to condition on.
        rv_mode : str, None
            Specifies how to interpret rvs and crvs. Valid options are:
            {'indices', 'names'}. If equal to 'indices', then the elements of
            crvs and rvs are interpreted as random variable indices. If
            equal to 'names', then the elements are interpreted as random
            variable names. If None, then the value of dist._rv_mode is
            consulted, which defaults to 'indices'.
        beta_min : float
            The minimum beta value for the curve. Defaults to 0.
        beta_max : float, None
            The maximum beta value for the curve. Defaults to 15. If None,
            iteratively find a beta value with nearly maximal relevance.
        beta_num : int
            The number of beta values for the curve. Defaults to 101.
        alpha : float
            The alpha value to utilize. 1.0 corresponds to the standard information
            bottleneck, while 0.0 corresponds to the deterministic bottleneck.
        method : {'sp', 'ba'}
            The method to utilize in computing the curve. If 'sp', utilize
            scipy.optimize; if 'ba' utilize the iterative Blahut-Arimoto
            algorithm. Defaults to 'sp'.
        divergence : func
            The divergence measure to use as a distortion. Defaults to the standard
            relative entropy.
        """
        self.dist = dist.copy()
        self.dist.make_dense()

        self._x, self._y = rvs if rvs is not None else ([0], [1])
        self._z = crvs if crvs is not None else []
        # presumably the index the auxiliary (bottleneck) variable takes once
        # appended to the distribution -- TODO confirm
        self._aux = [dist.outcome_length()]
        self._rv_mode = rv_mode

        # Joint pmf of (x, y) as an array with one axis per variable.
        self.p_xy = self.dist.coalesce([self._x, self._y])
        self.p_xy = self.p_xy.pmf.reshape(tuple(map(len, self.p_xy.alphabet)))

        args = {'dist': self.dist,
                'beta': 0.0,
                'alpha': alpha,
                'rvs': [self._x, self._y],
                'crvs': self._z,
                'rv_mode': self._rv_mode
                }

        if divergence is not None:  # pragma: no cover
            bottleneck = InformationBottleneckDivergence
            args['divergence'] = divergence
        else:
            bottleneck = InformationBottleneck
        self._bn = bottleneck(**args)

        self._max_complexity = entropy(mss(dist, self._x, self._y))
        self._max_relevance = total_correlation(dist, [self._x, self._y])
        self._max_rank = len(dist.marginal(self._x).outcomes)
        # The distortion of the beta = 0 solution is recorded as the maximum.
        self._max_distortion = self._bn.distortion(self._get_opt_sp(beta=0.0)[0])

        if np.isclose(alpha, 1.0):
            self.label = "IB"
        elif np.isclose(alpha, 0.0):
            self.label = "DIB"
        else:
            self.label = "GIB({:.3f})".format(alpha)

        beta_max = self.find_max_beta() if beta_max is None else beta_max
        self.betas = np.linspace(beta_min, beta_max, beta_num)

        self.compute(method)

    def __add__(self, other):  # pragma: no cover
        """
        Combine two IBCurves into an IBPlotter.

        Parameters
        ----------
        other : IBCurve
            The curve to aggregate with self.

        Returns
        -------
        plotter : IBPlotter
            A plotter with both self and other. `NotImplemented` is returned
            if `other` is not an IBCurve.
        """
        # Imported lazily so that plotting is only loaded when requested.
        from .plotting import IBPlotter
        if isinstance(other, IBCurve):
            plotter = IBPlotter(self, other)
            return plotter
        else:
            return NotImplemented

    def _get_opt_sp(self, beta, initial=None):
        """
        Compute the information bottleneck solution for beta using scipy.optimize.

        Parameters
        ----------
        beta : float
            The beta value to optimize for.
        initial : np.ndarray, None
            An initial optimization vector, useful for numerical continuation.

        Returns
        -------
        q : np.ndarray
            The matrix p(x, y, z, t)
        x0 : np.ndarray
            The found optima.
        """
        self._bn._beta = beta
        self._bn.optimize(x0=initial)
        # Copy so that later optimizations cannot alter the returned vector.
        x0 = self._bn._optima.copy()
        q_xyzt = self._bn.construct_joint(self._bn._optima)
        return q_xyzt, x0

    def _get_opt_ba(self, beta, initial=None):  # pragma: no cover
        """
        Compute the information bottleneck solution for beta using blahut-arimoto.

        Parameters
        ----------
        beta : float
            The beta value to optimize for.
        initial : np.ndarray, None
            Accepted for interface parity with `_get_opt_sp`; it is not used
            by the computation.

        Returns
        -------
        q : np.ndarray
            The matrix p(x, y, z, t)
        x0 : None
            Always None; there is no optimization vector to continue from.
        """
        q_xyt = blahut_arimoto_ib(p_xy=self.p_xy, beta=beta)[1]
        # Insert a length-one z axis so the result has the same layout as
        # the one produced by `_get_opt_sp`.
        q_xyzt = q_xyt[:, :, np.newaxis, :]
        return q_xyzt, None

    def compute(self, method='sp'):
        """
        Sweep beta and compute the information bottleneck curve.

        The beta values are visited from largest to smallest, and the optima
        found at each beta are handed to the next computation as its initial
        condition. The results are stored, in the same order as `self.betas`,
        in `self.complexities`, `self.entropies`, `self.relevances`,
        `self.errors`, `self.ranks`, `self.alphabets` and `self.distortions`.

        Parameters
        ----------
        method : {'sp', 'ba'}
            The method of computation to use. 'sp' denotes scipy.optimize;
            'ba' denotes blahut-arimoto.
        """
        get_opt = {'ba': self._get_opt_ba,
                   'sp': self._get_opt_sp,
                   }[method]

        complexities = []
        entropies = []
        relevances = []
        errors = []
        ranks = []
        alphabets = []
        distortions = []

        # Axes of the joint array returned by the solvers.
        x, y, z, t = [[0], [1], [2], [3]]

        x0 = None

        for beta in self.betas[::-1]:
            q_xyzt, x0 = get_opt(beta, x0)
            d = Distribution.from_ndarray(q_xyzt)
            complexities.append(total_correlation(d, [x, t], z))
            entropies.append(entropy(d, x, z))
            relevances.append(total_correlation(d, [y, t], z))
            errors.append(total_correlation(d, [x, y], z + t))
            distortions.append(self._bn.distortion(q_xyzt))

            q_xt = q_xyzt.sum(axis=(1, 2))
            # A value of t with no mass yields 0/0 = NaN; those entries are
            # zeroed below, so the warning carries no information.
            with np.errstate(divide='ignore', invalid='ignore'):
                q_x_t = (q_xt / q_xt.sum(axis=0, keepdims=True))
            q_x_t[np.isnan(q_x_t)] = 0

            ranks.append(np.linalg.matrix_rank(q_x_t, tol=1e-4))
            # Count the values of t which carry non-negligible probability.
            alphabets.append((q_xt.sum(axis=0) > 1e-6).sum())

        # Undo the reversed sweep so results line up with `self.betas`.
        self.complexities = np.asarray(complexities)[::-1]
        self.entropies = np.asarray(entropies)[::-1]
        self.relevances = np.asarray(relevances)[::-1]
        self.errors = np.asarray(errors)[::-1]
        self.ranks = np.asarray(ranks)[::-1]
        self.alphabets = np.asarray(alphabets)[::-1]
        self.distortions = np.asarray(distortions)[::-1]

    def find_max_beta(self):
        """
        Find a beta value at which the relevance is nearly maximal.

        Beta is grown geometrically (1.5, 2.25, ...) until the relevance
        computed at that beta is within tolerance of the maximum relevance.

        Returns
        -------
        beta_max : float
            The smallest tried beta value whose relevance is close to the
            maximum relevance.
        """
        beta_max = 1.0
        relevance = 0.0

        # NOTE: there is no iteration cap; the loop only ends once the
        # computed relevance reaches `self._max_relevance` within tolerance.
        while not np.isclose(relevance, self._max_relevance, atol=1e-5, rtol=1e-5):
            beta_max = 1.5*beta_max
            q, _ = self._get_opt_sp(beta=beta_max)
            relevance = self._bn.relevance(q)

        return beta_max

    def find_kinks(self):
        """
        Determine the beta values where new features are discovered.

        A kink is a step at which the rank increases and which is not
        immediately preceded by another increase, so a run of consecutive
        increases is reported once, at its start.

        Returns
        -------
        kinks : np.ndarray
            An array of beta values where new features are discovered. It is
            empty if the rank never increases.
        """
        diff = np.diff(self.ranks)
        jumps = np.flatnonzero(diff > 0)
        # The first step has no predecessor, so it is a kink whenever it is
        # an increase; `diff[jump - 1]` would wrap to the last step there.
        # The explicit integer dtype keeps an empty result usable as an index.
        kinks = np.asarray([jump for jump in jumps
                            if jump == 0 or diff[jump - 1] == 0], dtype=int)
        return self.betas[kinks]

    def plot(self, downsample=5):  # pragma: no cover
        """
        Construct an IBPlotter and utilize it to plot the information
        bottleneck curve.

        Parameters
        ----------
        downsample : int
            How frequently to display points along the IB curve.

        Returns
        -------
        fig : plt.figure
            The resulting figure.
        """
        from .plotting import IBPlotter
        plotter = IBPlotter(self)
        return plotter.plot(downsample)