A PyTorch library for differentiable two-sample tests

Overview

This package implements a total of six two-sample tests:

  • The classical Friedman-Rafsky test [FR79].
  • The classical k-nearest neighbours (kNN) test [FR83].
  • The differentiable Friedman-Rafsky test [DK17].
  • The differentiable k-nearest neighbours (kNN) test [DK17].
  • The maximum mean discrepancy (MMD) test [GBR+12].
  • The energy test [SzekelyR13].

These tests accept as input two samples and produce a statistic that should be large when the samples come from different distributions. The first two of these are not differentiable and can only be used for statistical testing, not for learning implicit generative models.

Additionally, the package contains code for marginal inference in models with cardinality potentials or spanning-tree constraints, which is used internally to implement the smooth graph tests. The code for inference with cardinality potentials has been adapted from the code accompanying [SST+12], while the original algorithm comes from [TSZ+12]. Please consider citing these papers if you use the smoothed k-NN test, which relies on this inference method.

Installation

Once you install PyTorch (following these instructions), you can install the package as:

python setup.py install
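
As a quick sanity check (this is just a suggestion, assuming the package installs under the name torch_two_sample used throughout this documentation), you can try importing it:

python -c "import torch_two_sample"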

Usage

Statistical testing

First, one has to create the object that implements the test. For example:

>>> import torch
>>> from torch_two_sample.statistics_diff import SmoothFRStatistic
>>> n = 128  # Sample size.
>>> fr_test = SmoothFRStatistic(n, n, cuda=False, compute_t_stat=True)

Note that we have used compute_t_stat=True, which means that fr_test.__call__ will compute a t-statistic, so that we can obtain an approximate p-value from it using the CDF of a standard normal. We have implemented t-statistics only for the smooth graph tests.

Let us work out a concrete example. We will draw three different samples from 5-dimensional Gaussians, so that the first two come from the same distribution, while the last one comes from a distribution with a different mean.

>>> import numpy as np
>>> # We fix the seeds so that the result is reproducible.
>>> torch.manual_seed(0)
>>> np.random.seed(0)
>>> dim = 5  # The data dimension.
>>> sample_1 = torch.randn(n, dim)
>>> sample_2 = torch.randn(n, dim)
>>> sample_3 = torch.randn(n, dim)
>>> sample_3[:, 0] += 1  # Shift the first coordinate by one.

We can now compute the t-statistics (the returned matrices will be discussed shortly).

>>> from torch.autograd import Variable
>>> t_val_12, matrix_12 = fr_test(Variable(sample_1), Variable(sample_2),
>>>                               alphas=[4.], ret_matrix=True)
>>> t_val_13, matrix_13 = fr_test(Variable(sample_1), Variable(sample_3),
>>>                               alphas=[4.], ret_matrix=True)

The approximate p-values can be then computed from the normal’s CDF as follows:

>>> from scipy.stats import norm   # To evaluate the CDF of a normal.
>>> print('1 vs 2', 1 - norm.cdf(t_val_12.data[0]))
1 vs 2 0.834159503205
>>> print('1 vs 3', 1 - norm.cdf(t_val_13.data[0]))
1 vs 3 1.31049615604e-11

We also provide the means to compute the p-value by sampling from the permutation null. Namely, if you set ret_matrix=True in the invocation of fr_test, the function returns a second value, which can then be used to compute the p-value. Every test has a pval method, which accepts this second return value and returns the p-value estimated from random permutations. We would like to point out that this method has been written in Cython, so it should execute reasonably fast. Concretely, we can compute the (sampled) exact p-value as follows.

>>> np.random.seed(0)  # The method internally uses numpy.
>>> print('1 vs 2', fr_test.pval(matrix_12, n_permutations=1000))
1 vs 2 0.853999972343
>>> print('1 vs 3', fr_test.pval(matrix_13, n_permutations=1000))
1 vs 3 0.0

Implicit model learning

We go through a simple example of learning an implicit generative model on MNIST. While here we present only the minimal code necessary to train a model, our repository provides a Jupyter notebook that builds on the code below and should be directly executable on your machine. As the base measure we will use a 10-dimensional Gaussian, and the following generative model:

>>> from torch import nn
>>> generator = nn.Sequential(
>>>   nn.Linear(10, 128),  # Receive a 10 dimensional noise vector as input.
>>>   nn.ReLU(),  # Then, a single hidden layer of 128 units with ReLU.
>>>   nn.Linear(128, 28 * 28),  # The output has 28 * 28 dimensions.
>>>   nn.Sigmoid())  # Squash the output to [0, 1].

To optimize we will use the torch.optim.Adam optimizer [KB15].

>>> optimizer = torch.optim.Adam(generator.parameters(), lr=1e-3)

As a loss function we will use the smoothed 1-NN loss with a batch size of 256.

>>> from torch_two_sample.statistics_diff import SmoothKNNStatistic
>>> batch_size = 256
>>> loss_fn = SmoothKNNStatistic(
>>>    batch_size, batch_size, False, 1, compute_t_stat=True)

Next, let us load the MNIST dataset using torchvision.

>>> from torchvision.datasets import MNIST
>>> from torchvision.transforms import ToTensor
>>> dataset = MNIST('mnist_dir', transform=ToTensor(), download=True)

We can then train the model for 50 epochs.

>>> from torch.utils.data import DataLoader
>>>
>>> for epoch in range(1, 51):
>>>   data_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)
>>>   # Note that we drop the last batch, as all batches have to be of equal size.
>>>   noise_tensor = torch.FloatTensor(batch_size, 10)
>>>   losses = []
>>>   noise = Variable(noise_tensor)
>>>   for batch, _ in data_loader:
>>>     batch = batch.view(batch_size, -1)  # We want one observation per row.
>>>     noise_tensor.normal_()
>>>     optimizer.zero_grad()
>>>     loss = loss_fn(Variable(batch), generator(noise), alphas=[0.1])
>>>     loss.backward()
>>>     losses.append(loss.data[0])
>>>     optimizer.step()
>>>   print('epoch {0:>2d}, avg loss {1}'.format(epoch, np.mean(losses)))
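
Once training is done, new digits can be generated by pushing Gaussian noise through the generator. The following sketch is our own addition (the reshaping to 28 x 28 images and the matplotlib calls are not part of the package):

>>> import matplotlib.pyplot as plt
>>> noise_tensor = torch.FloatTensor(16, 10).normal_()
>>> samples = generator(Variable(noise_tensor)).data.view(16, 28, 28)
>>> plt.imshow(samples[0].numpy(), cmap='gray')  # Show the first generated digit.
>>> plt.show()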

Statistics

Differentiable statistics

These tests can be used both for learning implicit models and for statistical two-sample testing.

class torch_two_sample.statistics_diff.SmoothFRStatistic(n_1, n_2, cuda, compute_t_stat=True)[source]

The smoothed Friedman-Rafsky test [DK17].

Parameters:
  • n_1 (int) – The number of points in the first sample.
  • n_2 (int) – The number of points in the second sample.
  • cuda (bool) – If true, the arguments to __call__() must be on the current CUDA device. Otherwise, they should be on the CPU.
  • compute_t_stat (bool) – If true, return a t-statistic instead of the raw test statistic.
__call__(sample_1, sample_2, alphas, norm=2, ret_matrix=False)[source]

Evaluate the smoothed Friedman-Rafsky test statistic.

The test accepts several inverse temperatures in alphas, does one test for each alpha, and takes their mean as the statistic. Namely, using the notation in [DK17], the value returned by this call when compute_t_stat=False is equal to:

\[-\frac{1}{m}\sum_{j=1}^{m} T_{\pi^*}^{1/\alpha_j}(\textrm{sample}_1, \textrm{sample}_2),\]

where \(m\) is the number of inverse temperatures in alphas.

If compute_t_stat=True, the returned value is the t-statistic of the above quantity under the permutation null. Note that we compute the negated statistic of what is used in [DK17], as it is exactly what we want to minimize when used as an objective for training implicit models.

Parameters:
  • sample_1 (torch.autograd.Variable) – The first sample, should be of size (n_1, d).
  • sample_2 (torch.autograd.Variable) – The second sample, should be of size (n_2, d).
  • alphas (list of float numbers) – The inverse temperatures.
  • norm (float) – Which norm to use when computing distances.
  • ret_matrix (bool) –

    If set, the call will also return a second variable.

    This variable can then be used to compute a p-value using pval().

Returns:

  • float – The test statistic, a t-statistic if compute_t_stat=True.
  • torch.autograd.Variable – Returned only if ret_matrix was set to true.

pval(matrix, n_permutations=1000)[source]

Compute a p-value using a permutation test.

Parameters:
  • matrix (torch.autograd.Variable) – The second return value of __call__() when ret_matrix was set to true.
  • n_permutations (int) – The number of permutations to sample.
Returns: The estimated p-value.
Return type: float

class torch_two_sample.statistics_diff.SmoothKNNStatistic(n_1, n_2, cuda, k, compute_t_stat=True)[source]

The smoothed k-nearest neighbours test [DK17].

Note that the k=1 case is computed directly using a SoftMax and should execute much faster than the statistics with k > 1.

Parameters:
  • n_1 (int) – The number of points in the first sample.
  • n_2 (int) – The number of points in the second sample.
  • cuda (bool) – If true, the arguments to __call__ must be on the current CUDA device. Otherwise, they should be on the CPU.
  • k (int) – The number of nearest neighbours (k in kNN).
  • compute_t_stat (bool) – If true, return a t-statistic instead of the raw test statistic.
__call__(sample_1, sample_2, alphas, norm=2, ret_matrix=False)[source]

Evaluate the smoothed kNN statistic.

The test accepts several inverse temperatures in alphas, does one test for each alpha, and takes their mean as the statistic. Namely, using the notation in [DK17], the value returned by this call when compute_t_stat=False is equal to:

\[-\frac{1}{m}\sum_{j=1}^{m} T_{\pi^*}^{1/\alpha_j}(\textrm{sample}_1, \textrm{sample}_2),\]

where \(m\) is the number of inverse temperatures in alphas.

If compute_t_stat=True, the returned value is the t-statistic of the above quantity under the permutation null. Note that we compute the negated statistic of what is used in [DK17], as it is exactly what we want to minimize when used as an objective for training implicit models.

Parameters:
  • sample_1 (torch.autograd.Variable) – The first sample, of size (n_1, d).
  • sample_2 (torch.autograd.Variable) – The second sample, of size (n_2, d).
  • alphas (list of float) – The inverse temperatures (smoothing strengths).
  • norm (float) – Which norm to use when computing distances.
  • ret_matrix (bool) –

    If set, the call will also return a second variable.

    This variable can then be used to compute a p-value using pval().

Returns:

  • float – The test statistic, a t-statistic if compute_t_stat=True.
  • torch.autograd.Variable – Returned only if ret_matrix was set to true.

pval(margs, n_permutations=1000)[source]

Compute a p-value using a permutation test.

Parameters:
  • margs (torch.autograd.Variable) – The second return value of __call__() when ret_matrix was set to true.
  • n_permutations (int) – The number of permutations to sample.
Returns: The estimated p-value.
Return type: float
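
The usage section above employs this statistic as a training loss; for completeness, here is a minimal sketch of using it for two-sample testing, analogous to the SmoothFRStatistic example (the toy data and the value of alphas are our own choices):

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.statistics_diff import SmoothKNNStatistic
>>> n = 128
>>> knn_test = SmoothKNNStatistic(n, n, False, 1, compute_t_stat=True)
>>> x = Variable(torch.randn(n, 5))
>>> y = Variable(torch.randn(n, 5) + 1.)  # Shifted mean in every coordinate.
>>> t_val, margs = knn_test(x, y, alphas=[4.], ret_matrix=True)
>>> p_value = knn_test.pval(margs, n_permutations=1000)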

class torch_two_sample.statistics_diff.MMDStatistic(n_1, n_2)[source]

The unbiased MMD test of [GBR+12].

The kernel used is equal to:

\[k(x, x') = \sum_{j=1}^k e^{-\alpha_j\|x - x'\|^2},\]

for the \(\alpha_j\) provided in __call__().

Parameters:
  • n_1 (int) – The number of points in the first sample.
  • n_2 (int) – The number of points in the second sample.
__call__(sample_1, sample_2, alphas, ret_matrix=False)[source]

Evaluate the statistic.

The kernel used is

\[k(x, x') = \sum_{j=1}^k e^{-\alpha_j \|x - x'\|^2},\]

for the provided alphas.

Parameters:
  • sample_1 (torch.autograd.Variable) – The first sample, of size (n_1, d).
  • sample_2 (torch.autograd.Variable) – The second sample, of size (n_2, d).
  • alphas (list of float) – The kernel parameters.
  • ret_matrix (bool) –

    If set, the call will also return a second variable.

    This variable can then be used to compute a p-value using pval().

Returns:

  • float – The test statistic.
  • torch.autograd.Variable – Returned only if ret_matrix was set to true.

pval(distances, n_permutations=1000)[source]

Compute a p-value using a permutation test.

Parameters:
  • distances (torch.autograd.Variable) – The second return value of __call__() when ret_matrix was set to true.
  • n_permutations (int) – The number of permutations to sample.
Returns: The estimated p-value.
Return type: float
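
Since the usage section does not cover the MMD test, here is a minimal sketch of how it can be used for testing (the toy samples and the kernel parameters in alphas are our own choices, not recommendations):

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.statistics_diff import MMDStatistic
>>> n_1, n_2 = 128, 128
>>> mmd_test = MMDStatistic(n_1, n_2)
>>> x = Variable(torch.randn(n_1, 5))
>>> y = Variable(torch.randn(n_2, 5) + 1.)  # Shifted mean in every coordinate.
>>> statistic, distances = mmd_test(x, y, alphas=[0.5], ret_matrix=True)
>>> p_value = mmd_test.pval(distances, n_permutations=1000)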

class torch_two_sample.statistics_diff.EnergyStatistic(n_1, n_2)[source]

The energy test of [SzekelyR13].

Parameters:
  • n_1 (int) – The number of points in the first sample.
  • n_2 (int) – The number of points in the second sample.
__call__(sample_1, sample_2, ret_matrix=False)[source]

Evaluate the statistic.

Parameters:
  • sample_1 (torch.autograd.Variable) – The first sample, of size (n_1, d).
  • sample_2 (torch.autograd.Variable) – The second sample, of size (n_2, d).
  • ret_matrix (bool) –

    If set, the call will also return a second variable.

    This variable can then be used to compute a p-value using pval().

Returns:

  • float – The test statistic.
  • torch.autograd.Variable – Returned only if ret_matrix was set to true.

pval(distances, n_permutations=1000)[source]

Compute a p-value using a permutation test.

Parameters:
  • distances (torch.autograd.Variable) – The second return value of __call__() when ret_matrix was set to true.
  • n_permutations (int) – The number of permutations to sample.
Returns: The estimated p-value.
Return type: float
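
A minimal sketch of the energy test, analogous to the MMD sketch above (again, the toy samples are our own):

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.statistics_diff import EnergyStatistic
>>> n_1, n_2 = 128, 128
>>> energy_test = EnergyStatistic(n_1, n_2)
>>> x = Variable(torch.randn(n_1, 5))
>>> y = Variable(torch.randn(n_2, 5) + 1.)
>>> statistic, distances = energy_test(x, y, ret_matrix=True)
>>> p_value = energy_test.pval(distances, n_permutations=1000)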

Non-differentiable statistics

These tests can only be used for statistical two-sample testing.

class torch_two_sample.statistics_nondiff.FRStatistic(n_1, n_2)[source]

The classical Friedman-Rafsky test [FR79].

Parameters:
  • n_1 (int) – The number of data points in the first sample.
  • n_2 (int) – The number of data points in the second sample.
__call__(sample_1, sample_2, norm=2, ret_matrix=False)[source]

Evaluate the non-smoothed Friedman-Rafsky test statistic.

Parameters:
  • sample_1 (torch.autograd.Variable) – The first sample, a variable of size (n_1, d).
  • sample_2 (torch.autograd.Variable) – The second sample, a variable of size (n_2, d).
  • norm (float) – Which norm to use when computing distances.
  • ret_matrix (bool) –

    If set, the call will also return a second variable.

    This variable can then be used to compute a p-value using pval().

Returns:

  • float – The number of edges that connect points from the same sample.
  • torch.autograd.Variable – Returned only if ret_matrix was set to true.

pval(mst, n_permutations=1000)[source]

Compute a p-value using a permutation test.

Parameters:
  • mst (torch.autograd.Variable) – The second return value of __call__() when ret_matrix was set to true.
  • n_permutations (int) – The number of permutations to sample.
Returns: The estimated p-value.
Return type: float
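
As a quick illustration of the classical test (our own toy data; as with the other tests, ret_matrix=True is used to obtain the object that pval() expects):

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.statistics_nondiff import FRStatistic
>>> n = 128
>>> fr = FRStatistic(n, n)
>>> x = Variable(torch.randn(n, 5))
>>> y = Variable(torch.randn(n, 5) + 1.)
>>> statistic, mst = fr(x, y, norm=2, ret_matrix=True)
>>> p_value = fr.pval(mst, n_permutations=1000)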

class torch_two_sample.statistics_nondiff.KNNStatistic(n_1, n_2, k)[source]

The classical k-NN test [FR83].

Parameters:
  • n_1 (int) – The number of data points in the first sample.
  • n_2 (int) – The number of data points in the second sample.
  • k (int) – The number of nearest neighbours (k in kNN).
__call__(sample_1, sample_2, norm=2, ret_matrix=False)[source]

Evaluate the non-smoothed kNN test statistic.

Parameters:
  • sample_1 (torch.autograd.Variable) – The first sample, a variable of size (n_1, d).
  • sample_2 (torch.autograd.Variable) – The second sample, a variable of size (n_2, d).
  • norm (float) – Which norm to use when computing distances.
  • ret_matrix (bool) –

    If set, the call will also return a second variable.

    This variable can then be used to compute a p-value using pval().

Returns:

  • float – The number of edges that connect points from the same sample.
  • torch.autograd.Variable (optional) – Returned only if ret_matrix was set to true.

pval(margs, n_permutations=1000)[source]

Compute a p-value using a permutation test.

Parameters:
  • margs (torch.autograd.Variable) – The second return value of __call__() when ret_matrix was set to true.
  • n_permutations (int) – The number of permutations to sample.
Returns: The estimated p-value.
Return type: float
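
A corresponding sketch for the classical k-NN test, again with our own toy data:

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.statistics_nondiff import KNNStatistic
>>> n = 128
>>> knn = KNNStatistic(n, n, 3)  # Use k = 3 nearest neighbours.
>>> x = Variable(torch.randn(n, 5))
>>> y = Variable(torch.randn(n, 5) + 1.)
>>> statistic, margs = knn(x, y, ret_matrix=True)
>>> p_value = knn.pval(margs, n_permutations=1000)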

Marginal inference

Spanning-tree distributions

class torch_two_sample.inference_trees.TreeMarginals(n_vertices, cuda)[source]

Perform marginal inference in models over spanning trees.

The model considered is of the form:

\[p(x) \propto \exp(\sum_{i=1}^m d_i x_i) \nu(x),\]

where \(x\) is a binary random vector with one coordinate per edge, and \(\nu(x)\) is one if \(x\) forms a spanning tree, or zero otherwise.

The numbers \(d_i\) are expected to be given by taking the upper triangular part of the adjacency matrix. To extract the upper triangular part of a matrix, or to reconstruct the matrix from it, you can use the functions triu() and to_mat().

Parameters:
  • n_vertices (int) – The number of vertices in the graph.
  • cuda (bool) – If true, the computation is done on the current CUDA device; otherwise, on the CPU.
__call__(d)[source]

Compute the marginals in the model.

Parameters: d (torch.autograd.Variable) – A vector of size n * (n - 1) // 2 containing the \(d_i\).
Returns: The marginal probabilities in a vector of size n * (n - 1) // 2.
Return type: torch.autograd.Variable
to_mat(triu)[source]

Given the upper triangular part, reconstruct the matrix.

Parameters: triu (torch.autograd.Variable) – The upper triangular part, should be of size n * (n - 1) // 2.
Returns: The (n, n) matrix whose upper triangular part is filled in with the given values, and the rest with zeros.
Return type: torch.autograd.Variable
triu(matrix)[source]

Given a matrix, extract its upper triangular part.

Parameters: matrix (torch.autograd.Variable) – A square matrix of size (n, n).
Returns: The upper triangular part of the given matrix, of size n * (n - 1) // 2.
Return type: torch.autograd.Variable
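
To illustrate how the three methods fit together, here is a minimal sketch (the random edge weights below are just placeholders for the \(d_i\), which in the smooth graph tests come from the pairwise distances):

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.inference_trees import TreeMarginals
>>> n_vertices = 5
>>> tree_marginals = TreeMarginals(n_vertices, cuda=False)
>>> weights = Variable(torch.randn(n_vertices, n_vertices))  # Placeholder adjacency matrix.
>>> d = tree_marginals.triu(weights)  # The upper triangular part, of size n * (n - 1) // 2.
>>> marginals = tree_marginals(d)  # Edge marginals, also of size n * (n - 1) // 2.
>>> marginals_matrix = tree_marginals.to_mat(marginals)  # Back to an (n, n) matrix.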

Cardinality potentials

torch_two_sample.inference_cardinality.inference_cardinality(node_potentials, cardinality_potentials)[source]

Perform inference in a graphical model of the form

\[p(x) \propto \exp( \sum_{i=1}^n x_iq_i + f(\sum_{i=1}^n x_i) ),\]

where \(x\) is a binary random vector. The vector \(q\) holds the node potentials, while \(f\) is the so-called cardinality potential.

Parameters:
  • node_potentials (torch.autograd.Variable) – The matrix holding the per-node potentials \(q\) of size (batch_size, n).
  • cardinality_potentials (torch.autograd.Variable) –

    The cardinality potential.

    Should be of size (batch_size, n_potentials). In each row, column i holds the value \(f(i)\). If n_potentials < n + 1, the remaining positions are assumed to be equal to -inf (i.e., they are given zero probability).
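
A minimal usage sketch; we take the function to return the per-node marginal probabilities, of size (batch_size, n), and the toy potentials below, which encourage exactly two variables to be switched on, are our own:

>>> import torch
>>> from torch.autograd import Variable
>>> from torch_two_sample.inference_cardinality import inference_cardinality
>>> batch_size, n = 2, 5
>>> node_potentials = Variable(torch.randn(batch_size, n))
>>> # f(0), ..., f(5): every cardinality except two gets a very low potential.
>>> f = torch.Tensor([-1e6, -1e6, 0., -1e6, -1e6, -1e6]).repeat(batch_size, 1)
>>> cardinality_potentials = Variable(f)
>>> marginals = inference_cardinality(node_potentials, cardinality_potentials)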

Bibliography

[DK17] J. Djolonga and A. Krause. Learning implicit generative models using differentiable graph tests. arXiv preprint arXiv:1709.01006, 2017.
[FR79] Jerome H. Friedman and Lawrence C. Rafsky. Multivariate generalizations of the Wald-Wolfowitz and Smirnov two-sample tests. Annals of Statistics, pages 697–717, 1979.
[FR83] Jerome H. Friedman and Lawrence C. Rafsky. Graph-theoretic measures of multivariate association and prediction. Annals of Statistics, pages 377–391, 1983.
[GBR+12] Arthur Gretton, Karsten M. Borgwardt, Malte J. Rasch, Bernhard Schölkopf, and Alexander Smola. A kernel two-sample test. Journal of Machine Learning Research, 13(Mar):723–773, 2012.
[KB15] Diederik Kingma and Jimmy Ba. Adam: a method for stochastic optimization. In International Conference on Learning Representations (ICLR), 2015.
[SST+12] Kevin Swersky, Ilya Sutskever, Daniel Tarlow, Richard S. Zemel, Ruslan R. Salakhutdinov, and Ryan P. Adams. Cardinality restricted Boltzmann machines. In Advances in Neural Information Processing Systems (NIPS), pages 3293–3301, 2012.
[SzekelyR13] Gábor J. Székely and Maria L. Rizzo. Energy statistics: a class of statistics based on distances. Journal of Statistical Planning and Inference, 143(8):1249–1272, 2013.
[TSZ+12] Daniel Tarlow, Kevin Swersky, Richard S. Zemel, Ryan Prescott Adams, and Brendan J. Frey. Fast exact inference for recursive cardinality models. In Uncertainty in Artificial Intelligence (UAI), 2012.
