
API

Import modules and set version.

bijectors

Define the bijectors used in the normalizing flows.

Bijector

Wrapper class for bijector functions

Source code in pzflow/bijectors.py
class Bijector:
    """Wrapper class for bijector functions"""

    def __init__(self, func: Callable) -> None:
        self._func = func
        update_wrapper(self, func)

    def __call__(self, *args, **kwargs) -> Tuple[InitFunction, Bijector_Info]:
        return self._func(*args, **kwargs)
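
For orientation, a small sketch (not from the source) of how a wrapped bijector is used: calling it returns the (InitFunction, Bijector_Info) pair described below. Reverse is used purely as an example.

from pzflow.bijectors import Reverse

# calling a Bijector returns an InitFunction and a Bijector_Info tuple
init_fun, bijector_info = Reverse()
print(bijector_info)  # ("Reverse", ())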

ForwardFunction

Return the output and log_det of the forward bijection on the inputs.

ForwardFunction of a Bijector, originally returned by the InitFunction of the Bijector.

Parameters:

params : a Jax pytree (required)
    A pytree of bijector parameters. This usually looks like a nested tuple or list of parameters.
inputs : jnp.ndarray (required)
    The data to be transformed by the bijection.

Returns:

outputs : jnp.ndarray
    Result of the forward bijection applied to the inputs.
log_det : jnp.ndarray
    The log determinant of the Jacobian evaluated at the inputs.

Source code in pzflow/bijectors.py
class ForwardFunction:
    """Return the output and log_det of the forward bijection on the inputs.

    ForwardFunction of a Bijector, originally returned by the
    InitFunction of the Bijector.

    Parameters
    ----------
    params : a Jax pytree
        A pytree of bijector parameters.
        This usually looks like a nested tuple or list of parameters.
    inputs : jnp.ndarray
        The data to be transformed by the bijection.

    Returns
    -------
    outputs : jnp.ndarray
        Result of the forward bijection applied to the inputs.
    log_det : jnp.ndarray
        The log determinant of the Jacobian evaluated at the inputs.
    """

    def __init__(self, func: Callable) -> None:
        self._func = func

    def __call__(
        self, params: Pytree, inputs: jnp.ndarray, **kwargs
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        return self._func(params, inputs, **kwargs)

InitFunction

Initialize the corresponding Bijector.

InitFunction returned by the initialization of a Bijector.

Parameters:

rng : jnp.ndarray (required)
    A Random Number Key from jax.random.PRNGKey.
input_dim : int (required)
    The input dimension of the bijection.

Returns:

params : a Jax pytree
    A pytree of bijector parameters. This usually looks like a nested tuple or list of parameters.
forward_fun : ForwardFunction
    The forward function of the Bijector.
inverse_fun : InverseFunction
    The inverse function of the Bijector.

Source code in pzflow/bijectors.py
class InitFunction:
    """Initialize the corresponding Bijector.

    InitFunction returned by the initialization of a Bijector.

    Parameters
    ----------
    rng : jnp.ndarray
        A Random Number Key from jax.random.PRNGKey.
    input_dim : int
        The input dimension of the bijection.

    Returns
    -------
    params : a Jax pytree
        A pytree of bijector parameters.
        This usually looks like a nested tuple or list of parameters.
    forward_fun : ForwardFunction
        The forward function of the Bijector.
    inverse_fun : InverseFunction
        The inverse function of the Bijector.
    """

    def __init__(self, func: Callable) -> None:
        self._func = func

    def __call__(
        self, rng: jnp.ndarray, input_dim: int, **kwargs
    ) -> Tuple[Pytree, ForwardFunction, InverseFunction]:
        return self._func(rng, input_dim, **kwargs)
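
A short sketch of the full lifecycle, using Reverse from this module purely for illustration (the data values are made up):

import jax.numpy as jnp
from jax import random
from pzflow.bijectors import Reverse

init_fun, bijector_info = Reverse()

# initialize the bijector for 3-dimensional inputs
params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 3)

x = jnp.array([[1.0, 2.0, 3.0]])
y, log_det = forward_fun(params, x)  # forward bijection and its log determinant
x_back, _ = inverse_fun(params, y)   # inverse bijection recovers x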

InverseFunction

Return the output and log_det of the inverse bijection on the inputs.

InverseFunction of a Bijector, originally returned by the InitFunction of the Bijector.

Parameters:

params : a Jax pytree (required)
    A pytree of bijector parameters. This usually looks like a nested tuple or list of parameters.
inputs : jnp.ndarray (required)
    The data to be transformed by the bijection.

Returns:

outputs : jnp.ndarray
    Result of the inverse bijection applied to the inputs.
log_det : jnp.ndarray
    The log determinant of the Jacobian evaluated at the inputs.

Source code in pzflow/bijectors.py
class InverseFunction:
    """Return the output and log_det of the inverse bijection on the inputs.

    InverseFunction of a Bijector, originally returned by the
    InitFunction of the Bijector.

    Parameters
    ----------
    params : a Jax pytree
        A pytree of bijector parameters.
        This usually looks like a nested tuple or list of parameters.
    inputs : jnp.ndarray
        The data to be transformed by the bijection.

    Returns
    -------
    outputs : jnp.ndarray
        Result of the inverse bijection applied to the inputs.
    log_det : jnp.ndarray
        The log determinant of the Jacobian evaluated at the inputs.
    """

    def __init__(self, func: Callable) -> None:
        self._func = func

    def __call__(
        self, params: Pytree, inputs: jnp.ndarray, **kwargs
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        return self._func(params, inputs, **kwargs)

Chain(inputs)

Bijector that chains multiple InitFunctions into a single InitFunction.

Parameters:

inputs : (Bijector1(), Bijector2(), ...)
    A container of Bijector calls to be chained together.

Returns:

InitFunction
    The InitFunction of the total chained Bijector.
Bijector_Info
    Tuple('Chain', Tuple(Bijector_Info for each bijection in the chain)). This allows the chain to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def Chain(
    *inputs: Sequence[Tuple[InitFunction, Bijector_Info]]
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that chains multiple InitFunctions into a single InitFunction.

    Parameters
    ----------
    inputs : (Bijector1(), Bijector2(), ...)
        A container of Bijector calls to be chained together.

    Returns
    -------
    InitFunction
        The InitFunction of the total chained Bijector.
    Bijector_Info
        Tuple('Chain', Tuple(Bijector_Info for each bijection in the chain))
        This allows the chain to be recreated later.
    """

    init_funs = tuple(i[0] for i in inputs)
    bijector_info = ("Chain", tuple(i[1] for i in inputs))

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):

        all_params, forward_funs, inverse_funs = [], [], []
        for init_f in init_funs:
            rng, layer_rng = random.split(rng)
            param, forward_f, inverse_f = init_f(layer_rng, input_dim)

            all_params.append(param)
            forward_funs.append(forward_f)
            inverse_funs.append(inverse_f)

        def bijector_chain(params, bijectors, inputs, **kwargs):
            log_dets = jnp.zeros(inputs.shape[0])
            for bijector, param in zip(bijectors, params):
                inputs, log_det = bijector(param, inputs, **kwargs)
                log_dets += log_det
            return inputs, log_dets

        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            return bijector_chain(params, forward_funs, inputs, **kwargs)

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            return bijector_chain(
                params[::-1], inverse_funs[::-1], inputs, **kwargs
            )

        return all_params, forward_fun, inverse_fun

    return init_fun, bijector_info
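
A brief sketch of chaining bijectors (the particular sequence here is only illustrative):

from jax import random
from pzflow.bijectors import Chain, Reverse, Roll

init_fun, bijector_info = Chain(Reverse(), Roll(shift=1))

# the chained bijector initializes just like any single bijector
params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 4)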

ColorTransform(ref_idx, mag_idx)

Bijector that calculates photometric colors from magnitudes.

Using ColorTransform restricts and impacts the order of columns in the corresponding normalizing flow. See the notes below for an example.

Parameters:

ref_idx : int (required)
    The index corresponding to the column of the reference band, which serves as a proxy for overall luminosity.
mag_idx : arraylike of int (required)
    The indices of the magnitude columns from which colors will be calculated.

Returns:

InitFunction
    The InitFunction of the ColorTransform Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Notes

ColorTransform requires careful management of column order in the bijector. This is best explained with an example:

Assume we have data
[redshift, u, g, ellipticity, r, i, z, y, mass]
Then
ColorTransform(ref_idx=4, mag_idx=[1, 2, 4, 5, 6, 7])
will output
[redshift, ellipticity, mass, r, u-g, g-r, r-i, i-z, z-y]

Notice how the non-magnitude columns are aggregated at the front of the array, maintaining their relative order from the original array. These values are then followed by the reference magnitude and the new colors.

Also notice that the magnitudes in mag_idx are assumed to be adjacent, so colors are calculated between consecutive entries. E.g. mag_idx=[1, 2, 5, 4, 6, 7] would have produced the colors [u-g, g-i, i-r, r-z, z-y]. You can chain multiple ColorTransforms back-to-back to create colors in a non-adjacent manner.
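
As a concrete sketch of the column-ordering example in these notes (the column names and data layout are the hypothetical ones above):

from pzflow.bijectors import ColorTransform

# data columns: [redshift, u, g, ellipticity, r, i, z, y, mass]
# r (index 4) is the reference band; indices 1, 2, 4, 5, 6, 7 are magnitudes
init_fun, bijector_info = ColorTransform(ref_idx=4, mag_idx=[1, 2, 4, 5, 6, 7])
# the forward bijection then yields
# [redshift, ellipticity, mass, r, u-g, g-r, r-i, i-z, z-y]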

Source code in pzflow/bijectors.py
@Bijector
def ColorTransform(
    ref_idx: int, mag_idx: int
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that calculates photometric colors from magnitudes.

    Using ColorTransform restricts and impacts the order of columns in the
    corresponding normalizing flow. See the notes below for an example.

    Parameters
    ----------
    ref_idx : int
        The index corresponding to the column of the reference band, which
        serves as a proxy for overall luminosity.
    mag_idx : arraylike of int
        The indices of the magnitude columns from which colors will be calculated.

    Returns
    -------
    InitFunction
        The InitFunction of the ColorTransform Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.

    Notes
    -----
    ColorTransform requires careful management of column order in the bijector.
    This is best explained with an example:

    Assume we have data
    [redshift, u, g, ellipticity, r, i, z, y, mass]
    Then
    ColorTransform(ref_idx=4, mag_idx=[1, 2, 4, 5, 6, 7])
    will output
    [redshift, ellipticity, mass, r, u-g, g-r, r-i, i-z, z-y]

    Notice how the non-magnitude columns are aggregated at the front of the
    array, maintaining their relative order from the original array.
    These values are then followed by the reference magnitude, and the new colors.

    Also notice that the magnitudes in mag_idx are assumed to be adjacent,
    so colors are calculated between consecutive entries. E.g.
    mag_idx=[1, 2, 5, 4, 6, 7] would have produced the colors
    [u-g, g-i, i-r, r-z, z-y]. You can chain multiple ColorTransforms
    back-to-back to create colors in a non-adjacent manner.
    """

    # validate parameters
    if ref_idx <= 0:
        raise ValueError("ref_idx must be a positive integer.")
    if not isinstance(ref_idx, int):
        raise ValueError("ref_idx must be an integer.")
    if ref_idx not in mag_idx:
        raise ValueError("ref_idx must be in mag_idx.")

    bijector_info = ("ColorTransform", (ref_idx, mag_idx))

    # convert mag_idx to an array
    mag_idx = jnp.array(mag_idx)

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):

        # array of all the indices
        all_idx = jnp.arange(input_dim)
        # indices for columns to stick at the front
        front_idx = jnp.setdiff1d(all_idx, mag_idx)
        # the index corresponding to the first magnitude
        mag0_idx = len(front_idx)

        # the new column order
        new_idx = jnp.concatenate((front_idx, mag_idx))
        # the new column for the reference magnitude
        new_ref = jnp.where(new_idx == ref_idx)[0][0]

        # define a convenience function for the forward_fun below
        # if the first magnitude is the reference mag, do nothing
        if ref_idx == mag_idx[0]:

            def mag0(outputs):
                return outputs

        # if the first magnitude is not the reference mag,
        # then we need to calculate the first magnitude (mag[0])
        else:

            def mag0(outputs):
                return outputs.at[:, mag0_idx].set(
                    outputs[:, mag0_idx] + outputs[:, new_ref],
                    indices_are_sorted=True,
                    unique_indices=True,
                )

        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            # re-order columns and calculate colors
            outputs = jnp.hstack(
                (
                    inputs[:, front_idx],  # other values
                    inputs[:, ref_idx, None],  # ref mag
                    -jnp.diff(inputs[:, mag_idx]),  # colors
                )
            )
            # determinant of Jacobian is zero
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            # convert all colors to be in terms of the first magnitude, mag[0]
            outputs = jnp.hstack(
                (
                    inputs[:, 0:mag0_idx],  # other values unchanged
                    inputs[:, mag0_idx, None],  # reference mag unchanged
                    jnp.cumsum(
                        inputs[:, mag0_idx + 1 :], axis=-1
                    ),  # all colors mag[i-1] - mag[i] --> mag[0] - mag[i]
                )
            )
            # calculate mag[0]
            outputs = mag0(outputs)
            # mag[i] = mag[0] - (mag[0] - mag[i])
            outputs = outputs.at[:, mag0_idx + 1 :].set(
                outputs[:, mag0_idx, None] - outputs[:, mag0_idx + 1 :],
                indices_are_sorted=True,
                unique_indices=True,
            )
            # return to original ordering
            outputs = outputs[:, jnp.argsort(new_idx)]
            # determinant of Jacobian is zero
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

InvSoftplus(column_idx, sharpness=1)

Bijector that applies inverse softplus to the specified column(s).

Applying the inverse softplus ensures that samples from that column will always be non-negative. This is because samples are the output of the inverse bijection -- so samples will have a softplus applied to them.

Parameters:

column_idx : int (required)
    An index or iterable of indices corresponding to the column(s) you wish to be transformed.
sharpness : float (default=1)
    The sharpness(es) of the softplus transformation. If more than one is provided, the list of sharpnesses must be of the same length as column_idx.

Returns:

InitFunction
    The InitFunction of the InvSoftplus Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def InvSoftplus(
    column_idx: int, sharpness: float = 1
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that applies inverse softplus to the specified column(s).

    Applying the inverse softplus ensures that samples from that column will
    always be non-negative. This is because samples are the output of the
    inverse bijection -- so samples will have a softplus applied to them.

    Parameters
    ----------
    column_idx : int
        An index or iterable of indices corresponding to the column(s)
        you wish to be transformed.
    sharpness : float; default=1
        The sharpness(es) of the softplus transformation. If more than one
        is provided, the list of sharpnesses must be of the same length as
        column_idx.

    Returns
    -------
    InitFunction
        The InitFunction of the InvSoftplus Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    idx = jnp.atleast_1d(column_idx)
    k = jnp.atleast_1d(sharpness)
    if len(idx) != len(k) and len(k) != 1:
        raise ValueError(
            "Please provide either a single sharpness or one for each column index."
        )

    bijector_info = ("InvSoftplus", (column_idx, sharpness))

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = inputs.at[:, idx].set(
                jnp.log(-1 + jnp.exp(k * inputs[:, idx])) / k,
            )
            log_det = jnp.log(1 + jnp.exp(-k * outputs[:, idx])).sum(axis=1)
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs.at[:, idx].set(
                jnp.log(1 + jnp.exp(k * inputs[:, idx])) / k,
            )
            log_det = -jnp.log(1 + jnp.exp(-k * inputs[:, idx])).sum(axis=1)
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
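
For example, a small sketch that keeps samples of the first column non-negative (the column index and sharpness are arbitrary choices):

from pzflow.bijectors import InvSoftplus

# samples flow through the inverse bijection, so column 0 gets a softplus
# applied to it and therefore stays non-negative
init_fun, bijector_info = InvSoftplus(column_idx=0, sharpness=1)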

NeuralSplineCoupling(K=16, B=5, hidden_layers=2, hidden_dim=128, transformed_dim=None, n_conditions=0, periodic=False)

A coupling layer bijection with rational quadratic splines.

This Bijector is a Coupling Layer [1,2], and as such only transforms the second half of the input dimensions (or the last N dimensions, where N = transformed_dim). In order to transform all of the dimensions, you need multiple Couplings interspersed with Bijectors that change the order of the input dimensions, e.g., Reverse, Shuffle, Roll, etc.

NeuralSplineCoupling uses piecewise rational quadratic splines, as developed in [3].

If periodic=True, then this is a Circular Spline as described in [4].
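
For instance, a rough sketch of interspersing couplings with Roll so that every dimension eventually gets transformed (the number of couplets and the default settings are arbitrary choices):

from pzflow.bijectors import Chain, NeuralSplineCoupling, Roll

# each coupling transforms the lower half of the dimensions, and each Roll
# rotates the columns so a different subset is transformed next time
init_fun, bijector_info = Chain(
    NeuralSplineCoupling(),
    Roll(shift=1),
    NeuralSplineCoupling(),
    Roll(shift=1),
)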

Parameters:

K : int (default=16)
    Number of bins in the spline (the number of knots is K+1).
B : float (default=5)
    Range of the splines. If periodic=False, outside of (-B,B) the transformation is just the identity. If periodic=True, the input is mapped into the appropriate location in the range (-B,B).
hidden_layers : int (default=2)
    The number of hidden layers in the neural network used to calculate the positions and derivatives of the spline knots.
hidden_dim : int (default=128)
    The width of the hidden layers in the neural network used to calculate the positions and derivatives of the spline knots.
transformed_dim : int (default=None)
    The number of dimensions transformed by the splines. Default is ceiling(input_dim / 2).
n_conditions : int (default=0)
    The number of variables to condition the bijection on.
periodic : bool (default=False)
    Whether to make this a periodic, Circular Spline [4].

Returns:

InitFunction
    The InitFunction of the NeuralSplineCoupling Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

References

[1] Laurent Dinh, David Krueger, Yoshua Bengio. NICE: Non-linear Independent Components Estimation. arXiv:1410.8516, 2015. http://arxiv.org/abs/1410.8516
[2] Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. Density Estimation Using Real NVP. arXiv:1605.08803, 2017. http://arxiv.org/abs/1605.08803
[3] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural Spline Flows. arXiv:1906.04032, 2019. https://arxiv.org/abs/1906.04032
[4] Rezende, Danilo Jimenez et al. Normalizing Flows on Tori and Spheres. arXiv:2002.02428, 2020. http://arxiv.org/abs/2002.02428

Source code in pzflow/bijectors.py
@Bijector
def NeuralSplineCoupling(
    K: int = 16,
    B: float = 5,
    hidden_layers: int = 2,
    hidden_dim: int = 128,
    transformed_dim: int = None,
    n_conditions: int = 0,
    periodic: bool = False,
) -> Tuple[InitFunction, Bijector_Info]:
    """A coupling layer bijection with rational quadratic splines.

    This Bijector is a Coupling Layer [1,2], and as such only transforms
    the second half of input dimensions (or the last N dimensions, where
    N = transformed_dim). In order to transform all of the dimensions,
    you need multiple Couplings interspersed with Bijectors that change
    the order of inputs dimensions, e.g., Reverse, Shuffle, Roll, etc.

    NeuralSplineCoupling uses piecewise rational quadratic splines,
    as developed in [3].

    If periodic=True, then this is a Circular Spline as described in [4].

    Parameters
    ----------
    K : int; default=16
        Number of bins in the spline (the number of knots is K+1).
    B : float; default=5
        Range of the splines.
        If periodic=False, outside of (-B,B), the transformation is just
        the identity. If periodic=True, the input is mapped into the
        appropriate location in the range (-B,B).
    hidden_layers : int; default=2
        The number of hidden layers in the neural network used to calculate
        the positions and derivatives of the spline knots.
    hidden_dim : int; default=128
        The width of the hidden layers in the neural network used to
        calculate the positions and derivatives of the spline knots.
    transformed_dim : int; optional
        The number of dimensions transformed by the splines.
        Default is ceiling(input_dim /2).
    n_conditions : int; default=0
        The number of variables to condition the bijection on.
    periodic : bool; default=False
        Whether to make this a periodic, Circular Spline [4].

    Returns
    -------
    InitFunction
        The InitFunction of the NeuralSplineCoupling Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.

    References
    ----------
    [1] Laurent Dinh, David Krueger, Yoshua Bengio. NICE: Non-linear
        Independent Components Estimation. arXiv: 1410.8516, 2015.
        http://arxiv.org/abs/1410.8516
    [2] Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio.
        Density Estimation Using Real NVP. arXiv: 1605.08803, 2017.
        http://arxiv.org/abs/1605.08803
    [3] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios.
        Neural Spline Flows. arXiv:1906.04032, 2019.
        https://arxiv.org/abs/1906.04032
    [4] Rezende, Danilo Jimenez et al.
        Normalizing Flows on Tori and Spheres. arxiv:2002.02428, 2020
        http://arxiv.org/abs/2002.02428
    """

    if not isinstance(periodic, bool):
        raise ValueError("`periodic` must be True or False.")

    bijector_info = (
        "NeuralSplineCoupling",
        (
            K,
            B,
            hidden_layers,
            hidden_dim,
            transformed_dim,
            n_conditions,
            periodic,
        ),
    )

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):

        if transformed_dim is None:
            upper_dim = input_dim // 2  # variables that determine NN params
            lower_dim = (
                input_dim - upper_dim
            )  # variables transformed by the NN
        else:
            upper_dim = input_dim - transformed_dim
            lower_dim = transformed_dim

        # create the neural network that will take in the upper dimensions and
        # will return the spline parameters to transform the lower dimensions
        network_init_fun, network_apply_fun = DenseReluNetwork(
            (3 * K - 1 + int(periodic)) * lower_dim, hidden_layers, hidden_dim
        )
        _, network_params = network_init_fun(rng, (upper_dim + n_conditions,))

        # calculate spline parameters as a function of the upper variables
        def spline_params(params, upper, conditions):
            key = jnp.hstack((upper, conditions))[
                :, : upper_dim + n_conditions
            ]
            outputs = network_apply_fun(params, key)
            outputs = jnp.reshape(
                outputs, [-1, lower_dim, 3 * K - 1 + int(periodic)]
            )
            W, H, D = jnp.split(outputs, [K, 2 * K], axis=2)
            W = 2 * B * softmax(W)
            H = 2 * B * softmax(H)
            D = softplus(D)
            return W, H, D

        @ForwardFunction
        def forward_fun(params, inputs, conditions, **kwargs):
            # lower dimensions are transformed as function of upper dimensions
            upper, lower = inputs[:, :upper_dim], inputs[:, upper_dim:]
            # widths, heights, derivatives = function(upper dimensions)
            W, H, D = spline_params(params, upper, conditions)
            # transform the lower dimensions with the Rational Quadratic Spline
            lower, log_det = RationalQuadraticSpline(
                lower, W, H, D, B, periodic, inverse=False
            )
            outputs = jnp.hstack((upper, lower))
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, conditions, **kwargs):
            # lower dimensions are transformed as function of upper dimensions
            upper, lower = inputs[:, :upper_dim], inputs[:, upper_dim:]
            # widths, heights, derivatives = function(upper dimensions)
            W, H, D = spline_params(params, upper, conditions)
            # transform the lower dimensions with the Rational Quadratic Spline
            lower, log_det = RationalQuadraticSpline(
                lower, W, H, D, B, periodic, inverse=True
            )
            outputs = jnp.hstack((upper, lower))
            return outputs, log_det

        return network_params, forward_fun, inverse_fun

    return init_fun, bijector_info

Reverse()

Bijector that reverses the order of inputs.

Returns:

InitFunction
    The InitFunction of the Reverse Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def Reverse() -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that reverses the order of inputs.

    Returns
    -------
    InitFunction
        The InitFunction of the Reverse Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    bijector_info = ("Reverse", ())

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = inputs[:, ::-1]
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs[:, ::-1]
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

Roll(shift=1)

Bijector that rolls inputs along their last column using jnp.roll.

Parameters:

shift : int (default=1)
    The number of places to roll.

Returns:

InitFunction
    The InitFunction of the Roll Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def Roll(shift: int = 1) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that rolls inputs along their last column using jnp.roll.

    Parameters
    ----------
    shift : int; default=1
        The number of places to roll.

    Returns
    -------
    InitFunction
        The InitFunction of the Roll Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    if not isinstance(shift, int):
        raise ValueError("shift must be an integer.")

    bijector_info = ("Roll", (shift,))

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = jnp.roll(inputs, shift=shift, axis=-1)
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = jnp.roll(inputs, shift=-shift, axis=-1)
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

RollingSplineCoupling(nlayers, shift=1, K=16, B=5, hidden_layers=2, hidden_dim=128, transformed_dim=None, n_conditions=0, periodic=False)

Bijector that alternates NeuralSplineCouplings and Roll bijections.

Parameters:

nlayers : int (required)
    The number of (NeuralSplineCoupling(), Roll()) couplets in the chain.
shift : int (default=1)
    How far the inputs are shifted on each Roll().
K : int (default=16)
    Number of bins in the RollingSplineCoupling.
B : float (default=5)
    Range of the splines in the RollingSplineCoupling. If periodic=False, outside of (-B,B) the transformation is just the identity. If periodic=True, the input is mapped into the appropriate location in the range (-B,B).
hidden_layers : int (default=2)
    The number of hidden layers in the neural network used to calculate the bins and derivatives in the RollingSplineCoupling.
hidden_dim : int (default=128)
    The width of the hidden layers in the neural network used to calculate the bins and derivatives in the RollingSplineCoupling.
transformed_dim : int (default=None)
    The number of dimensions transformed by the splines. Default is ceiling(input_dim / 2).
n_conditions : int (default=0)
    The number of variables to condition the bijection on.
periodic : bool (default=False)
    Whether to make this a periodic, Circular Spline.

Returns:

InitFunction
    The InitFunction of the RollingSplineCoupling Bijector.
Bijector_Info
    Nested tuple of the Bijector name and input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def RollingSplineCoupling(
    nlayers: int,
    shift: int = 1,
    K: int = 16,
    B: float = 5,
    hidden_layers: int = 2,
    hidden_dim: int = 128,
    transformed_dim: int = None,
    n_conditions: int = 0,
    periodic: bool = False,
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that alternates NeuralSplineCouplings and Roll bijections.

    Parameters
    ----------
    nlayers : int
        The number of (NeuralSplineCoupling(), Roll()) couplets in the chain.
    shift : int
        How far the inputs are shifted on each Roll().
    K : int; default=16
        Number of bins in the RollingSplineCoupling.
    B : float; default=5
        Range of the splines in the RollingSplineCoupling.
        If periodic=False, outside of (-B,B), the transformation is just
        the identity. If periodic=True, the input is mapped into the
        appropriate location in the range (-B,B).
    hidden_layers : int; default=2
        The number of hidden layers in the neural network used to calculate
        the bins and derivatives in the RollingSplineCoupling.
    hidden_dim : int; default=128
        The width of the hidden layers in the neural network used to
        calculate the bins and derivatives in the RollingSplineCoupling.
    transformed_dim : int; optional
        The number of dimensions transformed by the splines.
        Default is ceiling(input_dim /2).
    n_conditions : int; default=0
        The number of variables to condition the bijection on.
    periodic : bool; default=False
        Whether to make this a periodic, Circular Spline

    Returns
    -------
    InitFunction
        The InitFunction of the RollingSplineCoupling Bijector.
    Bijector_Info
        Nested tuple of the Bijector name and input parameters. This allows
        it to be recreated later.

    """
    return Chain(
        *(
            NeuralSplineCoupling(
                K=K,
                B=B,
                hidden_layers=hidden_layers,
                hidden_dim=hidden_dim,
                transformed_dim=transformed_dim,
                n_conditions=n_conditions,
                periodic=periodic,
            ),
            Roll(shift),
        )
        * nlayers
    )
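
A quick sketch of building and initializing the chained bijector (the layer count and dimensionality are arbitrary):

from jax import random
from pzflow.bijectors import RollingSplineCoupling

init_fun, bijector_info = RollingSplineCoupling(nlayers=2)

# initialized exactly like any other bijector on this page
params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), 3)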

Scale(scale)

Bijector that multiplies inputs by a scalar.

Parameters:

scale : float (required)
    Factor by which to scale inputs.

Returns:

InitFunction
    The InitFunction of the Scale Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def Scale(scale: float) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that multiplies inputs by a scalar.

    Parameters
    ----------
    scale : float
        Factor by which to scale inputs.

    Returns
    -------
    InitFunction
        The InitFunction of the Scale Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    if isinstance(scale, jnp.ndarray):
        if scale.dtype != jnp.float32:
            raise ValueError("scale must be a float or array of floats.")
    elif not isinstance(scale, float):
        raise ValueError("scale must be a float or array of floats.")

    bijector_info = ("Scale", (scale,))

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = scale * inputs
            log_det = jnp.log(scale ** inputs.shape[-1]) * jnp.ones(
                inputs.shape[0]
            )
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = 1 / scale * inputs
            log_det = -jnp.log(scale ** inputs.shape[-1]) * jnp.ones(
                inputs.shape[0]
            )
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

ShiftBounds(min, max, B=5)

Bijector that shifts the bounds of inputs so they lie in the range (-B, B).

Parameters:

min : float (required)
    The minimum of the input range.
max : float (required)
    The maximum of the input range.
B : float (default=5)
    The extent of the output bounds, which will be (-B, B).

Returns:

InitFunction
    The InitFunction of the ShiftBounds Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def ShiftBounds(
    min: float, max: float, B: float = 5
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector shifts the bounds of inputs so the lie in the range (-B, B).

    Parameters
    ----------
    min : float
        The minimum of the input range.
    max : float
        The maximum of the input range.
    B : float; default=5
        The extent of the output bounds, which will be (-B, B).

    Returns
    -------
    InitFunction
        The InitFunction of the ShiftBounds Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    min = jnp.atleast_1d(min)
    max = jnp.atleast_1d(max)
    if len(min) != len(max):
        raise ValueError(
            "Lengths of min and max do not match. "
            + "Please provide either a single min and max, "
            + "or a min and max for each dimension."
        )
    if (min > max).any():
        raise ValueError("All mins must be less than maxes.")

    bijector_info = ("ShiftBounds", (min, max, B))

    mean = (max + min) / 2
    half_range = (max - min) / 2

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = B * (inputs - mean) / half_range
            log_det = jnp.log(jnp.prod(B / half_range)) * jnp.ones(
                inputs.shape[0]
            )
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs * half_range / B + mean
            log_det = jnp.log(jnp.prod(half_range / B)) * jnp.ones(
                inputs.shape[0]
            )
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info
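
As a quick worked sketch (the numbers are made up): with min=0, max=10, and B=5, the forward transform maps 0 to -5, 5 to 0, and 10 to 5.

from pzflow.bijectors import ShiftBounds

# map data that lives in [0, 10] onto the range (-5, 5)
init_fun, bijector_info = ShiftBounds(0.0, 10.0, B=5)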

Shuffle()

Bijector that randomly permutes inputs.

Returns:

InitFunction
    The InitFunction of the Shuffle Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def Shuffle() -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that randomly permutes inputs.

    Returns
    -------
    InitFunction
        The InitFunction of the Shuffle Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    bijector_info = ("Shuffle", ())

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):

        perm = random.permutation(rng, jnp.arange(input_dim))
        inv_perm = jnp.argsort(perm)

        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = inputs[:, perm]
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs[:, inv_perm]
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

StandardScaler(means, stds)

Bijector that applies standard scaling to each input.

Each input dimension i has an associated mean u_i and standard dev s_i. Each input is rescaled as (input[i] - u_i)/s_i, so that each input dimension has mean zero and unit variance.
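
A small sketch with made-up per-column statistics (in practice these would be computed from the training data):

import jax.numpy as jnp
from pzflow.bijectors import StandardScaler

means = jnp.array([0.5, 20.0])  # mean of each column
stds = jnp.array([0.1, 2.0])    # standard deviation of each column

init_fun, bijector_info = StandardScaler(means, stds)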

Parameters:

means : jnp.ndarray (required)
    The mean of each column.
stds : jnp.ndarray (required)
    The standard deviation of each column.

Returns:

InitFunction
    The InitFunction of the StandardScaler Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def StandardScaler(
    means: jnp.array, stds: jnp.array
) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that applies standard scaling to each input.

    Each input dimension i has an associated mean u_i and standard dev s_i.
    Each input is rescaled as (input[i] - u_i)/s_i, so that each input dimension
    has mean zero and unit variance.

    Parameters
    ----------
    means : jnp.ndarray
        The mean of each column.
    stds : jnp.ndarray
        The standard deviation of each column.

    Returns
    -------
    InitFunction
        The InitFunction of the StandardScaler Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    bijector_info = ("StandardScaler", (means, stds))

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            outputs = (inputs - means) / stds
            log_det = jnp.log(1 / jnp.prod(stds)) * jnp.ones(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs * stds + means
            log_det = jnp.log(jnp.prod(stds)) * jnp.ones(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

UniformDequantizer(column_idx)

Bijector that dequantizes discrete variables with uniform noise.

Dequantizers are necessary for modeling discrete values with a flow. Note that this isn't technically a bijector.
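
A short sketch of dequantizing one discrete column (the column index is an arbitrary example):

from pzflow.bijectors import UniformDequantizer

# add uniform [0, 1) noise to column 2, which holds integer-valued data;
# the inverse bijection floors the column to recover the discrete values
init_fun, bijector_info = UniformDequantizer(column_idx=[2])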

Parameters:

column_idx : int (required)
    An index or iterable of indices corresponding to the column(s) with discrete values.

Returns:

InitFunction
    The InitFunction of the UniformDequantizer Bijector.
Bijector_Info
    Tuple of the Bijector name and the input parameters. This allows it to be recreated later.

Source code in pzflow/bijectors.py
@Bijector
def UniformDequantizer(column_idx: int) -> Tuple[InitFunction, Bijector_Info]:
    """Bijector that dequantizes discrete variables with uniform noise.

    Dequantizers are necessary for modeling discrete values with a flow.
    Note that this isn't technically a bijector.

    Parameters
    ----------
    column_idx : int
        An index or iterable of indices corresponding to the column(s) with
        discrete values.

    Returns
    -------
    InitFunction
        The InitFunction of the UniformDequantizer Bijector.
    Bijector_Info
        Tuple of the Bijector name and the input parameters.
        This allows it to be recreated later.
    """

    bijector_info = ("UniformDequantizer", (column_idx,))
    column_idx = jnp.array(column_idx)

    @InitFunction
    def init_fun(rng, input_dim, **kwargs):
        @ForwardFunction
        def forward_fun(params, inputs, **kwargs):
            u = random.uniform(
                random.PRNGKey(0), shape=inputs[:, column_idx].shape
            )
            outputs = inputs.astype(float)
            # .at[].set() is out-of-place in JAX, so re-assign the result
            outputs = outputs.at[:, column_idx].set(outputs[:, column_idx] + u)
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        @InverseFunction
        def inverse_fun(params, inputs, **kwargs):
            outputs = inputs.at[:, column_idx].set(
                jnp.floor(inputs[:, column_idx])
            )
            log_det = jnp.zeros(inputs.shape[0])
            return outputs, log_det

        return (), forward_fun, inverse_fun

    return init_fun, bijector_info

distributions

Define the latent distributions used in the normalizing flows.

CentBeta

Bases: LatentDist

A centered Beta distribution.

This distribution is just a regular Beta distribution, scaled and shifted to have support on the domain [-B, B] in each dimension.

Alpha and beta parameters for each dimension are learned during training.
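
A brief sketch of using the distribution directly; normally a flow manages the learned parameters, so here the distribution's initial parameters (stored in the private _params attribute shown in the source below) are reused purely for illustration.

from pzflow.distributions import CentBeta

latent = CentBeta(input_dim=2, B=5)

samples = latent.sample(latent._params, nsamples=100, seed=0)
log_probs = latent.log_prob(latent._params, samples)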

Source code in pzflow/distributions.py
class CentBeta(LatentDist):
    """A centered Beta distribution.

    This distribution is just a regular Beta distribution, scaled and shifted
    to have support on the domain [-B, B] in each dimension.

    Alpha and beta parameters for each dimension are learned during training.
    """

    def __init__(self, input_dim: int, B: float = 5) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        B : float; default=5
            The distribution has support (-B, B) along each dimension.
        """
        self.input_dim = input_dim
        self.B = B

        # save dist info
        self._params = tuple([(0.0, 0.0) for i in range(input_dim)])
        self.info = ("CentBeta", (input_dim, B))

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : a Jax pytree
            Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta)
            for the Nth dimension.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """
        log_prob = jnp.hstack(
            [
                beta.logpdf(
                    inputs[:, i],
                    a=jnp.exp(params[i][0]),
                    b=jnp.exp(params[i][1]),
                    loc=-self.B,
                    scale=2 * self.B,
                ).reshape(-1, 1)
                for i in range(self.input_dim)
            ]
        ).sum(axis=1)

        return log_prob

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta)
            for the Nth dimension.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        seed = np.random.randint(1e18) if seed is None else seed
        seeds = random.split(random.PRNGKey(seed), self.input_dim)
        samples = jnp.hstack(
            [
                random.beta(
                    seeds[i],
                    jnp.exp(params[i][0]),
                    jnp.exp(params[i][1]),
                    shape=(nsamples, 1),
                )
                for i in range(self.input_dim)
            ]
        )
        return 2 * self.B * (samples - 0.5)

__init__(input_dim, B=5)

Parameters:

input_dim : int (required)
    The dimension of the distribution.
B : float (default=5)
    The distribution has support (-B, B) along each dimension.
Source code in pzflow/distributions.py
def __init__(self, input_dim: int, B: float = 5) -> None:
    """
    Parameters
    ----------
    input_dim : int
        The dimension of the distribution.
    B : float; default=5
        The distribution has support (-B, B) along each dimension.
    """
    self.input_dim = input_dim
    self.B = B

    # save dist info
    self._params = tuple([(0.0, 0.0) for i in range(input_dim)])
    self.info = ("CentBeta", (input_dim, B))

log_prob(params, inputs)

Calculates log probability density of inputs.

Parameters:

params : a Jax pytree (required)
    Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta) for the Nth dimension.
inputs : jnp.ndarray (required)
    Input data for which log probability density is calculated.

Returns:

jnp.ndarray
    Device array of shape (inputs.shape[0],).

Source code in pzflow/distributions.py
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    params : a Jax pytree
        Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta)
        for the Nth dimension.
    inputs : jnp.ndarray
        Input data for which log probability density is calculated.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """
    log_prob = jnp.hstack(
        [
            beta.logpdf(
                inputs[:, i],
                a=jnp.exp(params[i][0]),
                b=jnp.exp(params[i][1]),
                loc=-self.B,
                scale=2 * self.B,
            ).reshape(-1, 1)
            for i in range(self.input_dim)
        ]
    ).sum(axis=1)

    return log_prob

sample(params, nsamples, seed=None)

Returns samples from the distribution.

Parameters:

params : a Jax pytree (required)
    Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta) for the Nth dimension.
nsamples : int (required)
    The number of samples to be returned.
seed : int (default=None)
    Sets the random seed for the samples.

Returns:

jnp.ndarray
    Device array of shape (nsamples, self.input_dim).

Source code in pzflow/distributions.py
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Returns samples from the distribution.

    Parameters
    ----------
    params : a Jax pytree
        Tuple of ((a1, b1), (a2, b2), ...) where aN,bN are log(alpha),log(beta)
        for the Nth dimension.
    nsamples : int
        The number of samples to be returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    jnp.ndarray
        Device array of shape (nsamples, self.input_dim).
    """
    seed = np.random.randint(1e18) if seed is None else seed
    seeds = random.split(random.PRNGKey(seed), self.input_dim)
    samples = jnp.hstack(
        [
            random.beta(
                seeds[i],
                jnp.exp(params[i][0]),
                jnp.exp(params[i][1]),
                shape=(nsamples, 1),
            )
            for i in range(self.input_dim)
        ]
    )
    return 2 * self.B * (samples - 0.5)

CentBeta13

Bases: LatentDist

A centered Beta distribution with alpha, beta = 13.

This distribution is just a regular Beta distribution, scaled and shifted to have support on the domain [-B, B] in each dimension.

Alpha, beta = 13 means that the distribution looks like a Gaussian distribution, but with hard cutoffs at +/- B.

Source code in pzflow/distributions.py
class CentBeta13(LatentDist):
    """A centered Beta distribution with alpha, beta = 13.

    This distribution is just a regular Beta distribution, scaled and shifted
    to have support on the domain [-B, B] in each dimension.

    Alpha, beta = 13 means that the distribution looks like a Gaussian
    distribution, but with hard cutoffs at +/- B.
    """

    def __init__(self, input_dim: int, B: float = 5) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        B : float; default=5
            The distribution has support (-B, B) along each dimension.
        """
        self.input_dim = input_dim
        self.B = B

        # save dist info
        self._params = tuple([(0.0, 0.0) for i in range(input_dim)])
        self.info = ("CentBeta13", (input_dim, B))
        self.a = 13
        self.b = 13

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : a Jax pytree
            Empty pytree -- this distribution doesn't have learnable parameters.
            This parameter is present to ensure a consistent interface.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """
        log_prob = jnp.hstack(
            [
                beta.logpdf(
                    inputs[:, i],
                    a=self.a,
                    b=self.b,
                    loc=-self.B,
                    scale=2 * self.B,
                ).reshape(-1, 1)
                for i in range(self.input_dim)
            ]
        ).sum(axis=1)

        return log_prob

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Empty pytree -- this distribution doesn't have learnable parameters.
            This parameter is present to ensure a consistent interface.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        seed = np.random.randint(1e18) if seed is None else seed
        seeds = random.split(random.PRNGKey(seed), self.input_dim)
        samples = jnp.hstack(
            [
                random.beta(
                    seeds[i],
                    self.a,
                    self.b,
                    shape=(nsamples, 1),
                )
                for i in range(self.input_dim)
            ]
        )
        return 2 * self.B * (samples - 0.5)

__init__(input_dim, B=5)

Parameters:

input_dim : int (required)
    The dimension of the distribution.
B : float (default=5)
    The distribution has support (-B, B) along each dimension.
Source code in pzflow/distributions.py
def __init__(self, input_dim: int, B: float = 5) -> None:
    """
    Parameters
    ----------
    input_dim : int
        The dimension of the distribution.
    B : float; default=5
        The distribution has support (-B, B) along each dimension.
    """
    self.input_dim = input_dim
    self.B = B

    # save dist info
    self._params = tuple([(0.0, 0.0) for i in range(input_dim)])
    self.info = ("CentBeta13", (input_dim, B))
    self.a = 13
    self.b = 13

log_prob(params, inputs)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
params a Jax pytree

Empty pytree -- this distribution doesn't have learnable parameters. This parameter is present to ensure a consistent interface.

required
inputs jnp.ndarray

Input data for which log probability density is calculated.

required

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0],).

Source code in pzflow/distributions.py
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    params : a Jax pytree
        Empty pytree -- this distribution doesn't have learnable parameters.
        This parameter is present to ensure a consistent interface.
    inputs : jnp.ndarray
        Input data for which log probability density is calculated.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """
    log_prob = jnp.hstack(
        [
            beta.logpdf(
                inputs[:, i],
                a=self.a,
                b=self.b,
                loc=-self.B,
                scale=2 * self.B,
            ).reshape(-1, 1)
            for i in range(self.input_dim)
        ]
    ).sum(axis=1)

    return log_prob

sample(params, nsamples, seed=None)

Returns samples from the distribution.

Parameters:

Name Type Description Default
params a Jax pytree

Empty pytree -- this distribution doesn't have learnable parameters. This parameter is present to ensure a consistent interface.

required
nsamples int

The number of samples to be returned.

required
seed int

Sets the random seed for the samples.

None

Returns:

Type Description
jnp.ndarray

Device array of shape (nsamples, self.input_dim).

Source code in pzflow/distributions.py
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Returns samples from the distribution.

    Parameters
    ----------
    params : a Jax pytree
        Empty pytree -- this distribution doesn't have learnable parameters.
        This parameter is present to ensure a consistent interface.
    nsamples : int
        The number of samples to be returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    jnp.ndarray
        Device array of shape (nsamples, self.input_dim).
    """
    seed = np.random.randint(1e18) if seed is None else seed
    seeds = random.split(random.PRNGKey(seed), self.input_dim)
    samples = jnp.hstack(
        [
            random.beta(
                seeds[i],
                self.a,
                self.b,
                shape=(nsamples, 1),
            )
            for i in range(self.input_dim)
        ]
    )
    return 2 * self.B * (samples - 0.5)

Joint

Bases: LatentDist

A joint distribution built from other distributions.

Note that each of the other distributions already has support for multiple dimensions. This is only useful if you want to combine different distributions for different dimensions, e.g. if your first dimension has a Uniform latent space and the second dimension has a CentBeta latent space.
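
A minimal sketch, not part of the library source, of joining two one-dimensional latent distributions; the choice of Uniform and Normal and the values below are illustrative assumptions.

from pzflow.distributions import Joint, Normal, Uniform

# first dimension uniform on (-5, 5), second dimension standard normal
latent = Joint(Uniform(1, B=5), Normal(1))

params = latent._params                              # one entry per sub-distribution
samples = latent.sample(params, nsamples=4, seed=0)  # shape (4, 2)
log_p = latent.log_prob(params, samples)             # shape (4,)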

Source code in pzflow/distributions.py
class Joint(LatentDist):
    """A joint distribution built from other distributions.

    Note that each of the other distributions already has support for
    multiple dimensions. This is only useful if you want to combine
    different distributions for different dimensions, e.g. if your first
    dimension has a Uniform latent space and the second dimension has a
    CentBeta latent space.
    """

    def __init__(self, *inputs: Union[LatentDist, tuple]) -> None:
        """
        Parameters
        ----------
        inputs: LatentDist or tuple
            The latent distributions to join together.
        """

        # if Joint info provided, use that for setup
        if inputs[0] == "Joint info":
            self.dists = [globals()[dist[0]](*dist[1]) for dist in inputs[1]]
        # otherwise, assume it's a list of distributions
        else:
            self.dists = inputs

        # save info
        self._params = [dist._params for dist in self.dists]
        self.input_dim = sum([dist.input_dim for dist in self.dists])
        self.info = (
            "Joint",
            ("Joint info", [dist.info for dist in self.dists]),
        )

        # save the indices at which inputs will be split for log_prob
        # they must be concretely saved ahead-of-time so that jax trace
        # works properly when jitting
        self._splits = jnp.cumsum(
            jnp.array([dist.input_dim for dist in self.dists])
        )[:-1]

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : Jax Pytree
            Parameters for the distributions.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """

        # split inputs for corresponding distribution
        inputs = jnp.split(inputs, self._splits, axis=1)

        # calculate log_prob with respect to each sub-distribution,
        # then sum all the log_probs for each input
        log_prob = jnp.hstack(
            [
                self.dists[i].log_prob(params[i], inputs[i]).reshape(-1, 1)
                for i in range(len(self.dists))
            ]
        ).sum(axis=1)

        return log_prob

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Parameters for the distributions.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """

        seed = np.random.randint(1e18) if seed is None else seed
        seeds = random.randint(
            random.PRNGKey(seed), (len(self.dists),), 0, int(1e9)
        )
        samples = jnp.hstack(
            [
                self.dists[i]
                .sample(params[i], nsamples, seeds[i])
                .reshape(nsamples, -1)
                for i in range(len(self.dists))
            ]
        )

        return samples

__init__(inputs)

Parameters:

Name Type Description Default
inputs Union[LatentDist, tuple]

The latent distributions to join together.

()
Source code in pzflow/distributions.py
def __init__(self, *inputs: Union[LatentDist, tuple]) -> None:
    """
    Parameters
    ----------
    inputs: LatentDist or tuple
        The latent distributions to join together.
    """

    # if Joint info provided, use that for setup
    if inputs[0] == "Joint info":
        self.dists = [globals()[dist[0]](*dist[1]) for dist in inputs[1]]
    # otherwise, assume it's a list of distributions
    else:
        self.dists = inputs

    # save info
    self._params = [dist._params for dist in self.dists]
    self.input_dim = sum([dist.input_dim for dist in self.dists])
    self.info = (
        "Joint",
        ("Joint info", [dist.info for dist in self.dists]),
    )

    # save the indices at which inputs will be split for log_prob
    # they must be concretely saved ahead-of-time so that jax trace
    # works properly when jitting
    self._splits = jnp.cumsum(
        jnp.array([dist.input_dim for dist in self.dists])
    )[:-1]

log_prob(params, inputs)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
params Jax Pytree

Parameters for the distributions.

required
inputs jnp.ndarray

Input data for which log probability density is calculated.

required

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0],).

Source code in pzflow/distributions.py
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    params : Jax Pytree
        Parameters for the distributions.
    inputs : jnp.ndarray
        Input data for which log probability density is calculated.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """

    # split inputs for corresponding distribution
    inputs = jnp.split(inputs, self._splits, axis=1)

    # calculate log_prob with respect to each sub-distribution,
    # then sum all the log_probs for each input
    log_prob = jnp.hstack(
        [
            self.dists[i].log_prob(params[i], inputs[i]).reshape(-1, 1)
            for i in range(len(self.dists))
        ]
    ).sum(axis=1)

    return log_prob

sample(params, nsamples, seed=None)

Returns samples from the distribution.

Parameters:

Name Type Description Default
params a Jax pytree

Parameters for the distributions.

required
nsamples int

The number of samples to be returned.

required
seed int

Sets the random seed for the samples.

None

Returns:

Type Description
jnp.ndarray

Device array of shape (nsamples, self.input_dim).

Source code in pzflow/distributions.py
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Returns samples from the distribution.

    Parameters
    ----------
    params : a Jax pytree
        Parameters for the distributions.
    nsamples : int
        The number of samples to be returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    jnp.ndarray
        Device array of shape (nsamples, self.input_dim).
    """

    seed = np.random.randint(1e18) if seed is None else seed
    seeds = random.randint(
        random.PRNGKey(seed), (len(self.dists),), 0, int(1e9)
    )
    samples = jnp.hstack(
        [
            self.dists[i]
            .sample(params[i], nsamples, seeds[i])
            .reshape(nsamples, -1)
            for i in range(len(self.dists))
        ]
    )

    return samples

LatentDist

Bases: ABC

Base class for latent distributions.
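
A minimal sketch of a custom latent distribution written against this interface; the class below is a hypothetical example, not part of the library. Besides the two abstract methods, the built-in distributions (and the Flow object) also rely on the _params, input_dim, and info attributes, so they are included here.

import numpy as np
from jax import random
from jax.scipy.stats import norm
from pzflow.distributions import LatentDist

class IndependentNormal(LatentDist):
    """Hypothetical latent distribution: standard normal in each dimension."""

    def __init__(self, input_dim: int) -> None:
        self.input_dim = input_dim
        self._params = ()                                # no learnable parameters
        self.info = ("IndependentNormal", (input_dim,))

    def log_prob(self, params, inputs):
        # sum the per-dimension log densities -> shape (inputs.shape[0],)
        return norm.logpdf(inputs).sum(axis=-1)

    def sample(self, params, nsamples, seed=None):
        seed = np.random.randint(1e18) if seed is None else seed
        return random.normal(random.PRNGKey(seed), shape=(nsamples, self.input_dim))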

Source code in pzflow/distributions.py
class LatentDist(ABC):
    """Base class for latent distributions."""

    info = ("LatentDist", ())

    @abstractmethod
    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculate log-probability of the inputs."""

    @abstractmethod
    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Sample from the distribution."""

log_prob(params, inputs) abstractmethod

Calculate log-probability of the inputs.

Source code in pzflow/distributions.py
@abstractmethod
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculate log-probability of the inputs."""

sample(params, nsamples, seed=None) abstractmethod

Sample from the distribution.

Source code in pzflow/distributions.py
@abstractmethod
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Sample from the distribution."""

Normal

Bases: LatentDist

A multivariate Gaussian distribution with mean zero and unit variance.

Note this distribution has infinite support, so it is not recommended that you use it with the spline coupling layers, which have compact support. If you do use the two together, you should set the support of the spline layers (using the spline parameter B) to be large enough that you rarely draw Gaussian samples outside the support of the splines.
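
For intuition (a standard result, not from the library docs): with spline support B = 5, a standard normal coordinate falls outside (-5, 5) with probability of roughly 6e-7, so the combination is usually safe in practice. A minimal usage sketch, assuming two dimensions:

from pzflow.distributions import Normal

latent = Normal(input_dim=2)
samples = latent.sample((), nsamples=4, seed=0)  # params are an empty pytree
log_p = latent.log_prob((), samples)             # shape (4,)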

Source code in pzflow/distributions.py
class Normal(LatentDist):
    """A multivariate Gaussian distribution with mean zero and unit variance.

    Note this distribution has infinite support, so it is not recommended that
    you use it with the spline coupling layers, which have compact support.
    If you do use the two together, you should set the support of the spline
    layers (using the spline parameter B) to be large enough that you rarely
    draw Gaussian samples outside the support of the splines.
    """

    def __init__(self, input_dim: int) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        """
        self.input_dim = input_dim

        # save dist info
        self._params = ()
        self.info = ("Normal", (input_dim,))

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : a Jax pytree
            Empty pytree -- this distribution doesn't have learnable parameters.
            This parameter is present to ensure a consistent interface.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """
        return multivariate_normal.logpdf(
            inputs,
            mean=jnp.zeros(self.input_dim),
            cov=jnp.identity(self.input_dim),
        )

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Empty pytree -- this distribution doesn't have learnable parameters.
            This parameter is present to ensure a consistent interface.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        seed = np.random.randint(1e18) if seed is None else seed
        return random.multivariate_normal(
            key=random.PRNGKey(seed),
            mean=jnp.zeros(self.input_dim),
            cov=jnp.identity(self.input_dim),
            shape=(nsamples,),
        )

__init__(input_dim)

Parameters:

Name Type Description Default
input_dim int

The dimension of the distribution.

required
Source code in pzflow/distributions.py
def __init__(self, input_dim: int) -> None:
    """
    Parameters
    ----------
    input_dim : int
        The dimension of the distribution.
    """
    self.input_dim = input_dim

    # save dist info
    self._params = ()
    self.info = ("Normal", (input_dim,))

log_prob(params, inputs)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
params a Jax pytree

Empty pytree -- this distribution doesn't have learnable parameters. This parameter is present to ensure a consistent interface.

required
inputs jnp.ndarray

Input data for which log probability density is calculated.

required

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0],).

Source code in pzflow/distributions.py
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    params : a Jax pytree
        Empty pytree -- this distribution doesn't have learnable parameters.
        This parameter is present to ensure a consistent interface.
    inputs : jnp.ndarray
        Input data for which log probability density is calculated.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """
    return multivariate_normal.logpdf(
        inputs,
        mean=jnp.zeros(self.input_dim),
        cov=jnp.identity(self.input_dim),
    )

sample(params, nsamples, seed=None)

Returns samples from the distribution.

Parameters:

Name Type Description Default
params a Jax pytree

Empty pytree -- this distribution doesn't have learnable parameters. This parameter is present to ensure a consistent interface.

required
nsamples int

The number of samples to be returned.

required
seed int

Sets the random seed for the samples.

None

Returns:

Type Description
jnp.ndarray

Device array of shape (nsamples, self.input_dim).

Source code in pzflow/distributions.py
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Returns samples from the distribution.

    Parameters
    ----------
    params : a Jax pytree
        Empty pytree -- this distribution doesn't have learnable parameters.
        This parameter is present to ensure a consistent interface.
    nsamples : int
        The number of samples to be returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    jnp.ndarray
        Device array of shape (nsamples, self.input_dim).
    """
    seed = np.random.randint(1e18) if seed is None else seed
    return random.multivariate_normal(
        key=random.PRNGKey(seed),
        mean=jnp.zeros(self.input_dim),
        cov=jnp.identity(self.input_dim),
        shape=(nsamples,),
    )

Tdist

Bases: LatentDist

A multivariate T distribution with mean zero and unit scale matrix.

The number of degrees of freedom (i.e. the weight of the tails) is learned during training.

Note this distribution has infinite support and potentially large tails, so it is not recommended to use this distribution with the spline coupling layers, which have compact support.

Source code in pzflow/distributions.py
class Tdist(LatentDist):
    """A multivariate T distribution with mean zero and unit scale matrix.

    The number of degrees of freedom (i.e. the weight of the tails) is learned
    during training.

    Note this distribution has infinite support and potentially large tails,
    so it is not recommended to use this distribution with the spline coupling
    layers, which have compact support.
    """

    def __init__(self, input_dim: int) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        """
        self.input_dim = input_dim

        # save dist info
        self._params = jnp.log(30.0)
        self.info = ("Tdist", (input_dim,))

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Uses method explained here:
        http://gregorygundersen.com/blog/2020/01/20/multivariate-t/

        Parameters
        ----------
        params : float
            The degrees of freedom (nu) of the t-distribution.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """
        cov = jnp.identity(self.input_dim)
        nu = jnp.exp(params)
        maha, log_det = _mahalanobis_and_logdet(inputs, cov)
        t = 0.5 * (nu + self.input_dim)
        A = gammaln(t)
        B = gammaln(0.5 * nu)
        C = self.input_dim / 2.0 * jnp.log(nu * jnp.pi)
        D = 0.5 * log_det
        E = -t * jnp.log(1 + (1.0 / nu) * maha)

        return A - B - C - D + E

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : float
            The degrees of freedom (nu) of the t-distribution.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        mean = jnp.zeros(self.input_dim)
        nu = jnp.exp(params)

        seed = np.random.randint(1e18) if seed is None else seed
        rng = np.random.default_rng(int(seed))
        x = jnp.array(rng.chisquare(nu, nsamples) / nu)
        z = random.multivariate_normal(
            key=random.PRNGKey(seed),
            mean=jnp.zeros(self.input_dim),
            cov=jnp.identity(self.input_dim),
            shape=(nsamples,),
        )
        samples = mean + z / jnp.sqrt(x)[:, None]
        return samples

__init__(input_dim)

Parameters:

Name Type Description Default
input_dim int

The dimension of the distribution.

required
Source code in pzflow/distributions.py
def __init__(self, input_dim: int) -> None:
    """
    Parameters
    ----------
    input_dim : int
        The dimension of the distribution.
    """
    self.input_dim = input_dim

    # save dist info
    self._params = jnp.log(30.0)
    self.info = ("Tdist", (input_dim,))

log_prob(params, inputs)

Calculates log probability density of inputs.

Uses method explained here: http://gregorygundersen.com/blog/2020/01/20/multivariate-t/
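
For reference (a standard result, not part of the library docs), the terms A through E assembled in the source below correspond to the multivariate Student-t log-density with zero mean and identity scale matrix:

\log p(x) = \ln\Gamma\!\left(\frac{\nu + d}{2}\right)
          - \ln\Gamma\!\left(\frac{\nu}{2}\right)
          - \frac{d}{2}\ln(\nu\pi)
          - \frac{1}{2}\ln\lvert\Sigma\rvert
          - \frac{\nu + d}{2}\ln\!\left(1 + \frac{m(x)}{\nu}\right)

where d is input_dim, \Sigma is the identity scale matrix (so the log-determinant term vanishes), and m(x) is the squared Mahalanobis distance returned by _mahalanobis_and_logdet.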

Parameters:

Name Type Description Default
params float

The degrees of freedom (nu) of the t-distribution.

required
inputs jnp.ndarray

Input data for which log probability density is calculated.

required

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0],).

Source code in pzflow/distributions.py
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Uses method explained here:
    http://gregorygundersen.com/blog/2020/01/20/multivariate-t/

    Parameters
    ----------
    params : float
        The degrees of freedom (nu) of the t-distribution.
    inputs : jnp.ndarray
        Input data for which log probability density is calculated.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """
    cov = jnp.identity(self.input_dim)
    nu = jnp.exp(params)
    maha, log_det = _mahalanobis_and_logdet(inputs, cov)
    t = 0.5 * (nu + self.input_dim)
    A = gammaln(t)
    B = gammaln(0.5 * nu)
    C = self.input_dim / 2.0 * jnp.log(nu * jnp.pi)
    D = 0.5 * log_det
    E = -t * jnp.log(1 + (1.0 / nu) * maha)

    return A - B - C - D + E

sample(params, nsamples, seed=None)

Returns samples from the distribution.

Parameters:

Name Type Description Default
params float

The degrees of freedom (nu) of the t-distribution.

required
nsamples int

The number of samples to be returned.

required
seed int

Sets the random seed for the samples.

None

Returns:

Type Description
jnp.ndarray

Device array of shape (nsamples, self.input_dim).

Source code in pzflow/distributions.py
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Returns samples from the distribution.

    Parameters
    ----------
    params : float
        The degrees of freedom (nu) of the t-distribution.
    nsamples : int
        The number of samples to be returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    jnp.ndarray
        Device array of shape (nsamples, self.input_dim).
    """
    mean = jnp.zeros(self.input_dim)
    nu = jnp.exp(params)

    seed = np.random.randint(1e18) if seed is None else seed
    rng = np.random.default_rng(int(seed))
    x = jnp.array(rng.chisquare(nu, nsamples) / nu)
    z = random.multivariate_normal(
        key=random.PRNGKey(seed),
        mean=jnp.zeros(self.input_dim),
        cov=jnp.identity(self.input_dim),
        shape=(nsamples,),
    )
    samples = mean + z / jnp.sqrt(x)[:, None]
    return samples

Uniform

Bases: LatentDist

A multivariate uniform distribution with support [-B, B].
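
A quick worked note, not part of the library docs: the density is constant, equal to 1/(2B)^d inside the box and zero outside, so log_prob returns -d * log(2B) for points inside the support and -inf otherwise. A minimal sketch, assuming two dimensions and B = 5:

import jax.numpy as jnp
from pzflow.distributions import Uniform

latent = Uniform(input_dim=2, B=5)
inside = jnp.array([[0.0, 1.0]])
outside = jnp.array([[0.0, 6.0]])
latent.log_prob((), inside)   # -2 * log(10) = -4.605...
latent.log_prob((), outside)  # -inf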

Source code in pzflow/distributions.py
class Uniform(LatentDist):
    """A multivariate uniform distribution with support [-B, B]."""

    def __init__(self, input_dim: int, B: float = 5) -> None:
        """
        Parameters
        ----------
        input_dim : int
            The dimension of the distribution.
        B : float; default=5
            The distribution has support (-B, B) along each dimension.
        """
        self.input_dim = input_dim
        self.B = B

        # save dist info
        self._params = ()
        self.info = ("Uniform", (input_dim, B))

    def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        params : Jax Pytree
            Empty pytree -- this distribution doesn't have learnable parameters.
            This parameter is present to ensure a consistent interface.
        inputs : jnp.ndarray
            Input data for which log probability density is calculated.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """

        # which inputs are inside the support of the distribution
        mask = jnp.prod((inputs >= -self.B) & (inputs <= self.B), axis=-1)

        # calculate log_prob
        log_prob = jnp.where(
            mask,
            -self.input_dim * jnp.log(2 * self.B),
            -jnp.inf,
        )

        return log_prob

    def sample(
        self, params: Pytree, nsamples: int, seed: int = None
    ) -> jnp.ndarray:
        """Returns samples from the distribution.

        Parameters
        ----------
        params : a Jax pytree
            Empty pytree -- this distribution doesn't have learnable parameters.
            This parameter is present to ensure a consistent interface.
        nsamples : int
            The number of samples to be returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        jnp.ndarray
            Device array of shape (nsamples, self.input_dim).
        """
        seed = np.random.randint(1e18) if seed is None else seed
        samples = random.uniform(
            random.PRNGKey(seed),
            shape=(nsamples, self.input_dim),
            minval=-self.B,
            maxval=self.B,
        )
        return jnp.array(samples)

__init__(input_dim, B=5)

Parameters:

Name Type Description Default
input_dim int

The dimension of the distribution.

required
B float

The distribution has support (-B, B) along each dimension.

5
Source code in pzflow/distributions.py
def __init__(self, input_dim: int, B: float = 5) -> None:
    """
    Parameters
    ----------
    input_dim : int
        The dimension of the distribution.
    B : float; default=5
        The distribution has support (-B, B) along each dimension.
    """
    self.input_dim = input_dim
    self.B = B

    # save dist info
    self._params = ()
    self.info = ("Uniform", (input_dim, B))

log_prob(params, inputs)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
params Jax Pytree

Empty pytree -- this distribution doesn't have learnable parameters. This parameter is present to ensure a consistent interface.

required
inputs jnp.ndarray

Input data for which log probability density is calculated.

required

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0],).

Source code in pzflow/distributions.py
def log_prob(self, params: Pytree, inputs: jnp.ndarray) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    params : Jax Pytree
        Empty pytree -- this distribution doesn't have learnable parameters.
        This parameter is present to ensure a consistent interface.
    inputs : jnp.ndarray
        Input data for which log probability density is calculated.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """

    # which inputs are inside the support of the distribution
    mask = jnp.prod((inputs >= -self.B) & (inputs <= self.B), axis=-1)

    # calculate log_prob
    log_prob = jnp.where(
        mask,
        -self.input_dim * jnp.log(2 * self.B),
        -jnp.inf,
    )

    return log_prob

sample(params, nsamples, seed=None)

Returns samples from the distribution.

Parameters:

Name Type Description Default
params a Jax pytree

Empty pytree -- this distribution doesn't have learnable parameters. This parameter is present to ensure a consistent interface.

required
nsamples int

The number of samples to be returned.

required
seed int

Sets the random seed for the samples.

None

Returns:

Type Description
jnp.ndarray

Device array of shape (nsamples, self.input_dim).

Source code in pzflow/distributions.py
def sample(
    self, params: Pytree, nsamples: int, seed: int = None
) -> jnp.ndarray:
    """Returns samples from the distribution.

    Parameters
    ----------
    params : a Jax pytree
        Empty pytree -- this distribution doesn't have learnable parameters.
        This parameter is present to ensure a consistent interface.
    nsamples : int
        The number of samples to be returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    jnp.ndarray
        Device array of shape (nsamples, self.input_dim).
    """
    seed = np.random.randint(1e18) if seed is None else seed
    samples = random.uniform(
        random.PRNGKey(seed),
        shape=(nsamples, self.input_dim),
        minval=-self.B,
        maxval=self.B,
    )
    return jnp.array(samples)

examples

Functions that return example data and an example flow trained on galaxy data. To see these examples in action, see the tutorial notebooks.

get_checkerboard_data()

Return DataFrame with discrete checkerboard data.

Source code in pzflow/examples.py
def get_checkerboard_data() -> pd.DataFrame:
    """Return DataFrame with discrete checkerboard data."""
    return _load_example_data("checkerboard-data")

get_city_data()

Return DataFrame with example city data.

The countries, names, population, and coordinates of 47,966 cities.

Subset of the Kaggle world cities database (https://www.kaggle.com/max-mind/world-cities-database). This database was downloaded from MaxMind. The license follows:

OPEN DATA LICENSE for MaxMind WorldCities and Postal Code Databases

Copyright (c) 2008 MaxMind Inc.  All Rights Reserved.

The database uses toponymic information, based on the Geographic Names
Data Base, containing official standard names approved by the United States
Board on Geographic Names and maintained by the National
Geospatial-Intelligence Agency. More information is available at the Maps
and Geodata link at www.nga.mil. The National Geospatial-Intelligence Agency
name, initials, and seal are protected by 10 United States Code Section 445.

It also uses free population data from Stefan Helders www.world-gazetteer.com.
Visit his website to download the free population data.  Our database
combines Stefan's population data with the list of all cities in the world.

All advertising materials and documentation mentioning features or use of
this database must display the following acknowledgment:
"This product includes data created by MaxMind, available from
http://www.maxmind.com/"

Redistribution and use with or without modification, are permitted provided
that the following conditions are met:
1. Redistributions must retain the above copyright notice, this list of
conditions and the following disclaimer in the documentation and/or other
materials provided with the distribution.
2. All advertising materials and documentation mentioning features or use of
this database must display the following acknowledgement:
"This product includes data created by MaxMind, available from
http://www.maxmind.com/"
3. "MaxMind" may not be used to endorse or promote products derived from this
database without specific prior written permission.

THIS DATABASE IS PROVIDED BY MAXMIND.COM ``AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL MAXMIND.COM BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
DATABASE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Source code in pzflow/examples.py
def get_city_data() -> pd.DataFrame:
    """Return DataFrame with example city data.

    The countries, names, population, and coordinates of 47,966 cities.

    Subset of the Kaggle world cities database.
    https://www.kaggle.com/max-mind/world-cities-database
    This database was downloaded from MaxMind. The license follows:

        OPEN DATA LICENSE for MaxMind WorldCities and Postal Code Databases

        Copyright (c) 2008 MaxMind Inc.  All Rights Reserved.

        The database uses toponymic information, based on the Geographic Names
        Data Base, containing official standard names approved by the United States
        Board on Geographic Names and maintained by the National
        Geospatial-Intelligence Agency. More information is available at the Maps
        and Geodata link at www.nga.mil. The National Geospatial-Intelligence Agency
        name, initials, and seal are protected by 10 United States Code Section 445.

        It also uses free population data from Stefan Helders www.world-gazetteer.com.
        Visit his website to download the free population data.  Our database
        combines Stefan's population data with the list of all cities in the world.

        All advertising materials and documentation mentioning features or use of
        this database must display the following acknowledgment:
        "This product includes data created by MaxMind, available from
        http://www.maxmind.com/"

        Redistribution and use with or without modification, are permitted provided
        that the following conditions are met:
        1. Redistributions must retain the above copyright notice, this list of
        conditions and the following disclaimer in the documentation and/or other
        materials provided with the distribution.
        2. All advertising materials and documentation mentioning features or use of
        this database must display the following acknowledgement:
        "This product includes data created by MaxMind, available from
        http://www.maxmind.com/"
        3. "MaxMind" may not be used to endorse or promote products derived from this
        database without specific prior written permission.

        THIS DATABASE IS PROVIDED BY MAXMIND.COM ``AS IS'' AND ANY
        EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
        WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
        DISCLAIMED. IN NO EVENT SHALL MAXMIND.COM BE LIABLE FOR ANY
        DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
        (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
        LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
        ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
        (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
        DATABASE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    """
    return _load_example_data("city-data")

get_example_flow()

Return a normalizing flow that was trained on galaxy data.

This flow was trained in the redshift_example.ipynb Jupyter notebook, on the example data available in pzflow.examples.galaxy_data. For more info: print(example_flow().info).
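
A minimal usage sketch; the sample call assumes Flow's sample(nsamples, seed=...) interface and is illustrative rather than definitive.

from pzflow.examples import get_example_flow

flow = get_example_flow()
print(flow.info)              # describes the data the flow was trained on
df = flow.sample(10, seed=0)  # assumed signature; draws 10 synthetic galaxies as a DataFrame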

Source code in pzflow/examples.py
def get_example_flow() -> Flow:
    """Return a normalizing flow that was trained on galaxy data.

    This flow was trained in the `redshift_example.ipynb` Jupyter notebook,
    on the example data available in `pzflow.examples.galaxy_data`.
    For more info: `print(example_flow().info)`.
    """
    this_dir, _ = os.path.split(__file__)
    flow_path = os.path.join(this_dir, f"{EXAMPLE_FILE_DIR}/example-flow.pzflow.pkl")
    flow = Flow(file=flow_path)
    return flow

get_galaxy_data()

Return DataFrame with example galaxy data.

100,000 galaxies from the Buzzard simulation [1], with redshifts in the range (0,2.3) and photometry in the LSST ugrizy bands.

References

[1] Joseph DeRose et al. The Buzzard Flock: Dark Energy Survey Synthetic Sky Catalogs. arXiv:1901.02401, 2019. https://arxiv.org/abs/1901.02401

Source code in pzflow/examples.py
def get_galaxy_data() -> pd.DataFrame:
    """Return DataFrame with example galaxy data.

    100,000 galaxies from the Buzzard simulation [1], with redshifts
    in the range (0,2.3) and photometry in the LSST ugrizy bands.

    References
    ----------
    [1] Joseph DeRose et al. The Buzzard Flock: Dark Energy Survey
    Synthetic Sky Catalogs. arXiv:1901.02401, 2019.
    https://arxiv.org/abs/1901.02401
    """
    return _load_example_data("galaxy-data")

get_twomoons_data()

Return DataFrame with two moons example data.

Two moons data originally from scikit-learn, i.e., sklearn.datasets.make_moons.

Source code in pzflow/examples.py
def get_twomoons_data() -> pd.DataFrame:
    """Return DataFrame with two moons example data.

    Two moons data originally from scikit-learn,
    i.e., `sklearn.datasets.make_moons`.
    """
    return _load_example_data("two-moons-data")

flow

Define the Flow object that defines the normalizing flow.

Flow

A normalizing flow that models tabular data.

Attributes:

Name Type Description
data_columns tuple

List of DataFrame columns that the flow expects/produces.

conditional_columns tuple

List of DataFrame columns on which the flow is conditioned.

latent distributions.LatentDist

The latent distribution of the normalizing flow. Has its own sample and log_prob methods.

data_error_model Callable

The error model for the data variables. See the docstring of __init__ for more details.

condition_error_model Callable

The error model for the conditional variables. See the docstring of __init__ for more details.

info Any

Object containing any kind of info included with the flow. Often describes the data the flow is trained on.
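
A minimal construction sketch. The column names are hypothetical, and the error-model contract (signature and return shape) is taken from the __init__ docstring in the source below; treat this as an illustration rather than a canonical recipe.

from jax import random
from pzflow.flow import Flow
from pzflow.distributions import Uniform

# A custom error model must accept (key, X, Xerr, nsamples) and return an
# array of shape (X.shape[0], nsamples, X.shape[1]). This one is Gaussian,
# which is also what the default error model assumes.
def my_error_model(key, X, Xerr, nsamples):
    eps = random.normal(key, shape=(X.shape[0], nsamples, X.shape[1]))
    return X[:, None, :] + eps * Xerr[:, None, :]

flow = Flow(
    data_columns=("redshift", "u", "g"),  # hypothetical column names
    latent=Uniform(3, B=5),               # dimension must match len(data_columns)
    data_error_model=my_error_model,
    seed=0,
)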

Source code in pzflow/flow.py
class Flow:
    """A normalizing flow that models tabular data.

    Attributes
    ----------
    data_columns : tuple
        List of DataFrame columns that the flow expects/produces.
    conditional_columns : tuple
        List of DataFrame columns on which the flow is conditioned.
    latent : distributions.LatentDist
        The latent distribution of the normalizing flow.
        Has its own sample and log_prob methods.
    data_error_model : Callable
        The error model for the data variables. See the docstring of
        __init__ for more details.
    condition_error_model : Callable
        The error model for the conditional variables. See the docstring
        of __init__ for more details.
    info : Any
        Object containing any kind of info included with the flow.
        Often describes the data the flow is trained on.
    """

    def __init__(
        self,
        data_columns: Sequence[str] = None,
        bijector: Tuple[InitFunction, Bijector_Info] = None,
        latent: distributions.LatentDist = None,
        conditional_columns: Sequence[str] = None,
        data_error_model: Callable = None,
        condition_error_model: Callable = None,
        autoscale_conditions: bool = True,
        seed: int = 0,
        info: Any = None,
        file: str = None,
        _dictionary: dict = None,
    ) -> None:
        """Instantiate a normalizing flow.

        Note that while all of the init parameters are technically optional,
        you must provide either data_columns OR file.
        In addition, if a file is provided, all other parameters must be None.

        Parameters
        ----------
        data_columns : Sequence[str]; optional
            Tuple, list, or other container of column names.
            These are the columns the flow expects/produces in DataFrames.
        bijector : Bijector Call; optional
            A Bijector call that consists of the bijector InitFunction that
            initializes the bijector and the tuple of Bijector Info.
            Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
            If not provided, the bijector can be set later using
            flow.set_bijector, or by calling flow.train, in which case the
            default bijector will be used. The default bijector is
            ShiftBounds -> RollingSplineCoupling, where the range of shift
            bounds is learned from the training data, and the dimensions of
            RollingSplineCoupling is inferred. The default bijector assumes
            that the latent has support [-5, 5] for every dimension.
        latent : distributions.LatentDist; optional
            The latent distribution for the normalizing flow. Can be any of
            the distributions from pzflow.distributions. If not provided,
            a uniform distribution is used with input_dim = len(data_columns),
            and B=5.
        conditional_columns : Sequence[str]; optional
            Names of columns on which to condition the normalizing flow.
        data_error_model : Callable; optional
            A callable that defines the error model for data variables.
            data_error_model must take key, X, Xerr, nsamples as arguments:
                - key is a jax rng key, e.g. jax.random.PRNGKey(0)
                - X is 2D array of data variables, where the order of variables
                    matches the order of the columns in data_columns
                - Xerr is the corresponding 2D array of errors
                - nsamples is number of samples to draw from error distribution
            data_error_model must return an array of samples with the shape
            (X.shape[0], nsamples, X.shape[1]).
            If data_error_model is not provided, Gaussian error model assumed.
        condition_error_model : Callable; optional
            A callable that defines the error model for conditional variables.
            condition_error_model must take key, X, Xerr, nsamples, where:
                - key is a jax rng key, e.g. jax.random.PRNGKey(0)
                - X is 2D array of conditional variables, where the order of
                    variables matches order of columns in conditional_columns
                - Xerr is the corresponding 2D array of errors
                - nsamples is number of samples to draw from error distribution
            condition_error_model must return array of samples with shape
            (X.shape[0], nsamples, X.shape[1]).
            If condition_error_model is not provided, Gaussian error model
            assumed.
        autoscale_conditions : bool; default=True
            Sets whether or not conditions are automatically standard scaled
            when passed to a conditional flow. I recommend you leave as True.
        seed : int; default=0
            The random seed for initial parameters
        info : Any; optional
            An object to attach to the info attribute.
        file : str; optional
            Path to file from which to load a pretrained flow.
            If a file is provided, all other parameters must be None.
        """

        # validate parameters
        if data_columns is None and file is None and _dictionary is None:
            raise ValueError("You must provide data_columns OR file.")
        if any(
            (
                data_columns is not None,
                bijector is not None,
                conditional_columns is not None,
                latent is not None,
                data_error_model is not None,
                condition_error_model is not None,
                info is not None,
            )
        ):
            if file is not None:
                raise ValueError(
                    "If providing a file, please do not provide any other parameters."
                )
            if _dictionary is not None:
                raise ValueError(
                    "If providing a dictionary, please do not provide any other parameters."
                )
        if file is not None and _dictionary is not None:
            raise ValueError("Only provide file or _dictionary, not both.")

        # if file or dictionary is provided, load everything from it
        if file is not None or _dictionary is not None:
            save_dict = self._save_dict()
            if file is not None:
                with open(file, "rb") as handle:
                    save_dict.update(pickle.load(handle))
            else:
                save_dict.update(_dictionary)

            if save_dict["class"] != self.__class__.__name__:
                raise TypeError(
                    f"This save file isn't a {self.__class__.__name__}. "
                    f"It is a {save_dict['class']}"
                )

            # load columns and dimensions
            self.data_columns = save_dict["data_columns"]
            self.conditional_columns = save_dict["conditional_columns"]
            self._input_dim = len(self.data_columns)
            self.info = save_dict["info"]

            # load the latent distribution
            self._latent_info = save_dict["latent_info"]
            self.latent = getattr(distributions, self._latent_info[0])(
                *self._latent_info[1]
            )

            # load the error models
            self.data_error_model = save_dict["data_error_model"]
            self.condition_error_model = save_dict["condition_error_model"]

            # load the bijector
            self._bijector_info = save_dict["bijector_info"]
            if self._bijector_info is not None:
                init_fun, _ = build_bijector_from_info(self._bijector_info)
                _, self._forward, self._inverse = init_fun(
                    random.PRNGKey(0), self._input_dim
                )
            self._params = save_dict["params"]

            # load the conditional means and stds
            self._condition_means = save_dict["condition_means"]
            self._condition_stds = save_dict["condition_stds"]

            # set whether or not to automatically standard scale any
            # conditions passed to the normalizing flow
            self._autoscale_conditions = save_dict["autoscale_conditions"]

        # if no file is provided, use provided parameters
        else:
            self.data_columns = tuple(data_columns)
            self._input_dim = len(self.data_columns)
            self.info = info

            if conditional_columns is None:
                self.conditional_columns = None
                self._condition_means = None
                self._condition_stds = None
            else:
                self.conditional_columns = tuple(conditional_columns)
                self._condition_means = jnp.zeros(
                    len(self.conditional_columns)
                )
                self._condition_stds = jnp.ones(len(self.conditional_columns))

            # set whether or not to automatically standard scale any
            # conditions passed to the normalizing flow
            self._autoscale_conditions = autoscale_conditions

            # set up the latent distribution
            if latent is None:
                self.latent = distributions.Uniform(self._input_dim, 5)
            else:
                self.latent = latent
            self._latent_info = self.latent.info

            # make sure the latent distribution and data_columns have the
            # same number of dimensions
            if self.latent.input_dim != len(data_columns):
                raise ValueError(
                    f"The latent distribution has {self.latent.input_dim} "
                    f"dimensions, but data_columns has {len(data_columns)} "
                    "dimensions. They must match!"
                )

            # set up the error models
            if data_error_model is None:
                self.data_error_model = gaussian_error_model
            else:
                self.data_error_model = data_error_model
            if condition_error_model is None:
                self.condition_error_model = gaussian_error_model
            else:
                self.condition_error_model = condition_error_model

            # set up the bijector
            if bijector is not None:
                self.set_bijector(bijector, seed=seed)
            # if no bijector was provided, set bijector_info to None
            else:
                self._bijector_info = None

    def _check_bijector(self) -> None:
        if self._bijector_info is None:
            raise ValueError(
                "The bijector has not been set up yet! "
                "You can do this by calling "
                "flow.set_bijector(bijector, params), "
                "or by calling train, in which case the default "
                "bijector will be used."
            )

    def set_bijector(
        self,
        bijector: Tuple[InitFunction, Bijector_Info],
        params: Pytree = None,
        seed: int = 0,
    ) -> None:
        """Set the bijector.

        Parameters
        ----------
        bijector : Bijector Call
            A Bijector call that consists of the bijector InitFunction that
            initializes the bijector and the tuple of Bijector Info.
            Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
        params : Pytree; optional
            A Pytree of bijector parameters. If not provided, the bijector
            will be initialized with random parameters.
        seed: int; default=0
            A random seed for initializing the bijector with random parameters.
        """

        # set up the bijector
        init_fun, self._bijector_info = bijector
        bijector_params, self._forward, self._inverse = init_fun(
            random.PRNGKey(seed), self._input_dim
        )

        # check if params were passed
        bijector_params = params if params is not None else bijector_params

        # save the bijector params along with the latent params
        self._params = (self.latent._params, bijector_params)

    def _set_default_bijector(
        self, inputs: pd.DataFrame, seed: int = 0
    ) -> None:
        # Set the default bijector
        # which is ShiftBounds -> RollingSplineCoupling

        # get the min/max for each data column
        data = inputs[list(self.data_columns)].to_numpy()
        mins = data.min(axis=0)
        maxs = data.max(axis=0)

        # determine how many conditional columns we have
        n_conditions = (
            0
            if self.conditional_columns is None
            else len(self.conditional_columns)
        )

        self.set_bijector(
            Chain(
                ShiftBounds(mins, maxs, 4.0),
                RollingSplineCoupling(
                    len(self.data_columns), n_conditions=n_conditions
                ),
            ),
            seed=seed,
        )

    def _get_conditions(self, inputs: pd.DataFrame) -> jnp.ndarray:
        # Return an array of the bijector conditions.

        # if this isn't a conditional flow, just return empty conditions
        if self.conditional_columns is None:
            conditions = jnp.zeros((inputs.shape[0], 1))
        # if this a conditional flow, return an array of the conditions
        else:
            columns = list(self.conditional_columns)
            conditions = jnp.array(inputs[columns].to_numpy())
            conditions = (
                conditions - self._condition_means
            ) / self._condition_stds
        return conditions

    def _get_err_samples(
        self,
        key,
        inputs: pd.DataFrame,
        err_samples: int,
        type: str = "data",
        skip: str = None,
    ) -> jnp.ndarray:
        # Draw error samples for each row of inputs.

        X = inputs.copy()

        # get list of columns
        if type == "data":
            columns = list(self.data_columns)
            error_model = self.data_error_model
        elif type == "conditions":
            if self.conditional_columns is None:
                return jnp.zeros((err_samples * X.shape[0], 1))
            else:
                columns = list(self.conditional_columns)
                error_model = self.condition_error_model
        else:
            raise ValueError("type must be `data` or `conditions`.")

        # make sure all relevant variables have error columns
        for col in columns:
            # if errors not provided for the column, fill in zeros
            if f"{col}_err" not in inputs.columns and col != skip:
                X[f"{col}_err"] = jnp.zeros(X.shape[0])
            # if we are skipping this column, fill in nan's
            elif col == skip:
                X[col] = jnp.nan * jnp.zeros(X.shape[0])
                X[f"{col}_err"] = jnp.nan * jnp.zeros(X.shape[0])

        # pull out relevant columns
        err_columns = [col + "_err" for col in columns]
        X, Xerr = jnp.array(X[columns].to_numpy()), jnp.array(
            X[err_columns].to_numpy()
        )

        # generate samples
        Xsamples = error_model(key, X, Xerr, err_samples)
        Xsamples = Xsamples.reshape(X.shape[0] * err_samples, X.shape[1])

        # delete the column corresponding to skip
        if skip is not None:
            idx = columns.index(skip)
            Xsamples = jnp.delete(Xsamples, idx, axis=1)

        # if these are samples of conditions, standard scale them!
        if type == "conditions":
            Xsamples = (
                Xsamples - self._condition_means
            ) / self._condition_stds

        return Xsamples

    def _log_prob(
        self, params: Pytree, inputs: jnp.ndarray, conditions: jnp.ndarray
    ) -> jnp.ndarray:
        # Log prob for arrays.

        # calculate log_prob
        u, log_det = self._forward(params[1], inputs, conditions=conditions)
        log_prob = self.latent.log_prob(params[0], u) + log_det
        # set NaN's to negative infinity (i.e. zero probability)
        log_prob = jnp.nan_to_num(log_prob, nan=jnp.NINF)
        return log_prob

    def log_prob(
        self, inputs: pd.DataFrame, err_samples: int = None, seed: int = None
    ) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        inputs : pd.DataFrame
            Input data for which log probability density is calculated.
            Every column in self.data_columns must be present.
            If self.conditional_columns is not None, those must be present
            as well. If other columns are present, they are ignored.
        err_samples : int; default=None
            Number of samples from the error distribution to average over for
            the log_prob calculation. If provided, Gaussian errors are assumed,
            and method will look for error columns in `inputs`. Error columns
            must end in `_err`. E.g. the error column for the variable `u` must
            be `u_err`. Zero error assumed for any missing error columns.
        seed : int; default=None
            Random seed for drawing the samples with Gaussian errors.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0],).
        """

        # check that the bijector exists
        self._check_bijector()

        if err_samples is None:
            # convert data to an array with columns ordered
            columns = list(self.data_columns)
            X = jnp.array(inputs[columns].to_numpy())
            # get conditions
            conditions = self._get_conditions(inputs)
            # calculate log_prob
            return self._log_prob(self._params, X, conditions)

        else:
            # validate nsamples
            assert isinstance(
                err_samples, int
            ), "err_samples must be a positive integer."
            assert err_samples > 0, "err_samples must be a positive integer."
            # get Gaussian samples
            seed = np.random.randint(1e18) if seed is None else seed
            key = random.PRNGKey(seed)
            X = self._get_err_samples(key, inputs, err_samples, type="data")
            C = self._get_err_samples(
                key, inputs, err_samples, type="conditions"
            )
            # calculate log_probs
            log_probs = self._log_prob(self._params, X, C)
            probs = jnp.exp(log_probs.reshape(-1, err_samples))
            return jnp.log(probs.mean(axis=1))

    def posterior(
        self,
        inputs: pd.DataFrame,
        column: str,
        grid: jnp.ndarray,
        marg_rules: dict = None,
        normalize: bool = True,
        err_samples: int = None,
        seed: int = None,
        batch_size: int = None,
        nan_to_zero: bool = True,
    ) -> jnp.ndarray:
        """Calculates posterior distributions for the provided column.

        Calculates the conditional posterior distribution, assuming the
        data values in the other columns of the DataFrame.

        Parameters
        ----------
        inputs : pd.DataFrame
            Data on which the posterior distributions are conditioned.
            Must have columns matching self.data_columns, *except*
            for the column specified for the posterior (see below).
        column : str
            Name of the column for which the posterior distribution
            is calculated. Must be one of the columns in self.data_columns.
            However, whether or not this column is one of the columns in
            `inputs` is irrelevant.
        grid : jnp.ndarray
            Grid on which to calculate the posterior.
        marg_rules : dict; optional
            Dictionary with rules for marginalizing over missing variables.
            The dictionary must contain the key "flag", which gives the flag
            that indicates a missing value. E.g. if missing values are given
            the value 99, the dictionary should contain {"flag": 99}.
            The dictionary must also contain {"name": callable} for any
            variables that will need to be marginalized over, where name is
            the name of the variable, and callable is a callable that takes
            the row of variables and returns a grid over which to marginalize
            the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
            Note: the callable for a given name must *always* return an array
            of the same length, regardless of the input row.
        err_samples : int; default=None
            Number of samples from the error distribution to average over for
            the posterior calculation. If provided, Gaussian errors are assumed,
            and method will look for error columns in `inputs`. Error columns
            must end in `_err`. E.g. the error column for the variable `u` must
            be `u_err`. Zero error assumed for any missing error columns.
        seed : int; default=None
            Random seed for drawing the samples with Gaussian errors.
        batch_size : int; default=None
            Size of batches in which to calculate posteriors. If None, all
            posteriors are calculated simultaneously. Simultaneous calculation
            is faster, but memory intensive for large data sets.
        normalize : boolean; default=True
            Whether to normalize the posterior so that it integrates to 1.
        nan_to_zero : bool; default=True
            Whether to convert NaN's to zero probability in the final pdfs.

        Returns
        -------
        jnp.ndarray
            Device array of shape (inputs.shape[0], grid.size).
        """

        # check that the bijector exists
        self._check_bijector()

        # get the index of the provided column, and remove it from the list
        columns = list(self.data_columns)
        idx = columns.index(column)
        columns.remove(column)

        nrows = inputs.shape[0]
        batch_size = nrows if batch_size is None else batch_size

        # make sure indices run 0 -> nrows
        inputs = inputs.reset_index(drop=True)

        if err_samples is not None:
            # validate nsamples
            assert isinstance(
                err_samples, int
            ), "err_samples must be a positive integer."
            assert err_samples > 0, "err_samples must be a positive integer."
            # set the seed
            seed = np.random.randint(1e18) if seed is None else seed
            key = random.PRNGKey(seed)

        # empty array to hold pdfs
        pdfs = jnp.zeros((nrows, len(grid)))

        # if marginalization rules were passed, we will loop over the rules
        # and repeatedly call this method
        if marg_rules is not None:
            # if the flag is NaN, we must use jnp.isnan to check for flags
            if np.isnan(marg_rules["flag"]):

                def check_flags(data):
                    return np.isnan(data)

            # else we use jnp.isclose to check for flags
            else:

                def check_flags(data):
                    return np.isclose(data, marg_rules["flag"])

            # first calculate pdfs for unflagged rows
            unflagged_idx = inputs[
                ~check_flags(inputs[columns]).any(axis=1)
            ].index.tolist()
            unflagged_pdfs = self.posterior(
                inputs=inputs.iloc[unflagged_idx],
                column=column,
                grid=grid,
                err_samples=err_samples,
                seed=seed,
                batch_size=batch_size,
                normalize=False,
                nan_to_zero=nan_to_zero,
            )

            # save these pdfs in the big array
            pdfs = pdfs.at[unflagged_idx, :].set(
                unflagged_pdfs,
                indices_are_sorted=True,
                unique_indices=True,
            )

            # we will keep track of all the rows we've already calculated
            # posteriors for
            already_done = unflagged_idx

            # now we will loop over the rules in marg_rules
            for name, rule in marg_rules.items():
                # ignore the flag, because that's not a column in the data
                if name == "flag":
                    continue

                # get the list of new rows for which we need to calculate posteriors
                flagged_idx = inputs[check_flags(inputs[name])].index.tolist()
                flagged_idx = list(set(flagged_idx).difference(already_done))

                # if flagged_idx is empty, move on!
                if len(flagged_idx) == 0:
                    continue

                # get the marginalization grid for each row
                marg_grids = (
                    inputs.iloc[flagged_idx]
                    .apply(rule, axis=1, result_type="expand")
                    .to_numpy()
                )

                # make a new data frame with the marginalization grids replacing
                # the values of the flag in the column
                marg_inputs = pd.DataFrame(
                    np.repeat(
                        inputs.iloc[flagged_idx].to_numpy(),
                        marg_grids.shape[1],
                        axis=0,
                    ),
                    columns=inputs.columns,
                )
                marg_inputs[name] = marg_grids.reshape(marg_inputs.shape[0], 1)

                # remove the error column if it's present
                marg_inputs.drop(
                    f"{name}_err", axis=1, inplace=True, errors="ignore"
                )

                # calculate posteriors for these
                marg_pdfs = self.posterior(
                    inputs=marg_inputs,
                    column=column,
                    grid=grid,
                    marg_rules=marg_rules,
                    err_samples=err_samples,
                    seed=seed,
                    batch_size=batch_size,
                    normalize=False,
                    nan_to_zero=nan_to_zero,
                )

                # sum over the marginalized dimension
                marg_pdfs = marg_pdfs.reshape(
                    len(flagged_idx), marg_grids.shape[1], grid.size
                )
                marg_pdfs = marg_pdfs.sum(axis=1)

                # save the new pdfs in the big array
                pdfs = pdfs.at[flagged_idx, :].set(
                    marg_pdfs,
                    indices_are_sorted=True,
                    unique_indices=True,
                )

                # add these flagged indices to the list of rows already done
                already_done += flagged_idx

        # now for the main posterior calculation loop
        else:
            # loop through batches
            for batch_idx in range(0, nrows, batch_size):
                # get the data batch
                # and, if this is a conditional flow, the corresponding conditions
                batch = inputs.iloc[batch_idx : batch_idx + batch_size]

                # if not drawing samples, just grab batch and conditions
                if err_samples is None:
                    conditions = self._get_conditions(batch)
                    batch = jnp.array(batch[columns].to_numpy())
                # if only drawing condition samples...
                elif len(self.data_columns) == 1:
                    conditions = self._get_err_samples(
                        key, batch, err_samples, type="conditions"
                    )
                    batch = jnp.repeat(
                        batch[columns].to_numpy(), err_samples, axis=0
                    )
                # if drawing data and condition samples...
                else:
                    conditions = self._get_err_samples(
                        key, batch, err_samples, type="conditions"
                    )
                    batch = self._get_err_samples(
                        key, batch, err_samples, skip=column, type="data"
                    )

                # make a new copy of each row for each value of the column
                # for which we are calculating the posterior
                batch = jnp.hstack(
                    (
                        jnp.repeat(
                            batch[:, :idx],
                            len(grid),
                            axis=0,
                        ),
                        jnp.tile(grid, len(batch))[:, None],
                        jnp.repeat(
                            batch[:, idx:],
                            len(grid),
                            axis=0,
                        ),
                    )
                )

                # make similar copies of the conditions
                conditions = jnp.repeat(conditions, len(grid), axis=0)

                # calculate probability densities
                log_prob = self._log_prob(
                    self._params, batch, conditions
                ).reshape((-1, len(grid)))
                prob = jnp.exp(log_prob)
                # if we were Gaussian sampling, average over the samples
                if err_samples is not None:
                    prob = prob.reshape(-1, err_samples, len(grid))
                    prob = prob.mean(axis=1)
                # add the pdfs to the bigger list
                pdfs = pdfs.at[batch_idx : batch_idx + batch_size, :].set(
                    prob,
                    indices_are_sorted=True,
                    unique_indices=True,
                )

        if normalize:
            # normalize so they integrate to one
            pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
        if nan_to_zero:
            # set NaN's equal to zero probability
            pdfs = jnp.nan_to_num(pdfs, nan=0.0)
        return pdfs

    def sample(
        self,
        nsamples: int = 1,
        conditions: pd.DataFrame = None,
        save_conditions: bool = True,
        seed: int = None,
    ) -> pd.DataFrame:
        """Returns samples from the normalizing flow.

        Parameters
        ----------
        nsamples : int; default=1
            The number of samples to be returned.
        conditions : pd.DataFrame; optional
            If this is a conditional flow, you must pass conditions for
            each sample. nsamples will be drawn for each row in conditions.
        save_conditions : bool; default=True
            If true, conditions will be saved in the DataFrame of samples
            that is returned.
        seed : int; optional
            Sets the random seed for the samples.

        Returns
        -------
        pd.DataFrame
            Pandas DataFrame of samples.
        """

        # check that the bijector exists
        self._check_bijector()

        # validate nsamples
        assert isinstance(
            nsamples, int
        ), "nsamples must be a positive integer."
        assert nsamples > 0, "nsamples must be a positive integer."

        if self.conditional_columns is not None and conditions is None:
            raise ValueError(
                f"Must provide the following conditions\n{self.conditional_columns}"
            )

        # if this isn't a conditional flow, get empty conditions
        if self.conditional_columns is None:
            conditions = jnp.zeros((nsamples, 1))
        # otherwise get conditions and make `nsamples` copies of each
        else:
            conditions_idx = list(conditions.index)
            conditions = self._get_conditions(conditions)
            conditions_idx = np.repeat(conditions_idx, nsamples)
            conditions = jnp.repeat(conditions, nsamples, axis=0)

        # draw from latent distribution
        u = self.latent.sample(self._params[0], conditions.shape[0], seed)
        # take the inverse back to the data distribution
        x = self._inverse(self._params[1], u, conditions=conditions)[0]
        # if not conditional, this is all we need
        if self.conditional_columns is None:
            x = pd.DataFrame(np.array(x), columns=self.data_columns)
        # but if conditional
        else:
            if save_conditions:
                # unscale the conditions
                conditions = (
                    conditions * self._condition_stds + self._condition_means
                )
                x = pd.DataFrame(
                    np.array(jnp.hstack((x, conditions))),
                    columns=self.data_columns + self.conditional_columns,
                ).set_index(conditions_idx)
            else:
                # reindex according to the conditions
                x = pd.DataFrame(
                    np.array(x), columns=self.data_columns
                ).set_index(conditions_idx)

        # return the samples!
        return x

    def _save_dict(self) -> dict:
        ### Returns the dictionary of all flow params to be saved.
        save_dict = {"class": self.__class__.__name__}
        keys = [
            "data_columns",
            "conditional_columns",
            "condition_means",
            "condition_stds",
            "data_error_model",
            "condition_error_model",
            "autoscale_conditions",
            "info",
            "latent_info",
            "bijector_info",
            "params",
        ]
        for key in keys:
            try:
                save_dict[key] = getattr(self, key)
            except AttributeError:
                try:
                    save_dict[key] = getattr(self, "_" + key)
                except AttributeError:
                    save_dict[key] = None

        return save_dict

    def save(self, file: str) -> None:
        """Saves the flow to a file.

        Pickles the flow and saves it to a file that can be passed as
        the `file` argument during flow instantiation.

        WARNING: Currently, this method only works for bijectors that are
        implemented in the `bijectors` module. If you want to save a flow
        with a custom bijector, you either need to add the bijector to that
        module, or handle the saving and loading on your end.

        Parameters
        ----------
        file : str
            Path to where the flow will be saved.
            Extension `.pkl` will be appended if not already present.
        """
        save_dict = self._save_dict()

        with open(file, "wb") as handle:
            pickle.dump(save_dict, handle, recurse=True)

    def train(
        self,
        inputs: pd.DataFrame,
        val_set: pd.DataFrame = None,
        epochs: int = 100,
        batch_size: int = 1024,
        optimizer: Callable = None,
        loss_fn: Callable = None,
        convolve_errs: bool = False,
        patience: int = None,
        best_params: bool = True,
        seed: int = 0,
        verbose: bool = False,
        progress_bar: bool = False,
    ) -> list:
        """Trains the normalizing flow on the provided inputs.

        Parameters
        ----------
        inputs : pd.DataFrame
            Data on which to train the normalizing flow.
            Must have columns matching `self.data_columns`.
        val_set : pd.DataFrame; default=None
            Validation set, of same format as inputs. If provided,
            validation loss will be calculated at the end of each epoch.
        epochs : int; default=100
            Number of epochs to train.
        batch_size : int; default=1024
            Batch size for training.
        optimizer : optax optimizer
            An optimizer from Optax. default = optax.adam(learning_rate=1e-3)
            see https://optax.readthedocs.io/en/latest/index.html for more.
        loss_fn : Callable; optional
            A function to calculate the loss: `loss = loss_fn(params, x)`.
            If not provided, will be `-mean(log_prob)`.
        convolve_errs : bool; default=False
            Whether to draw new data from the error distributions during
            each epoch of training. Method will look for error columns in
            `inputs`. Error columns must end in `_err`. E.g. the error column
            for the variable `u` must be `u_err`. Zero error assumed for
            any missing error columns. The error distribution is set during
            flow instantiation.
        patience : int; optional
            Controls early stopping. Training will stop if the
            loss doesn't decrease for this number of epochs. Note if a
            validation set is provided, the validation loss is used.
        best_params : bool; default=True
            Whether to use the params from the epoch with the lowest loss.
            Note if a validation set is provided, the epoch with the lowest
            validation loss is chosen. If False, the params from the final
            epoch are saved.
        seed : int; default=0
            A random seed to control the batching, the (optional) error
            sampling, and the creation of the default bijector (the latter
            only happens if you didn't set up the bijector during Flow
            instantiation).
        verbose : bool; default=False
            If true, print the training loss every 5% of epochs.
        progress_bar : bool; default=False
            If true, display a tqdm progress bar during training.

        Returns
        -------
        list
            List of training losses from every epoch. If no val_set provided,
            these are just training losses. If val_set is provided, then the
            first element is the list of training losses, while the second is
            the list of validation losses.
        """

        # split the seed
        rng = np.random.default_rng(seed)
        batch_seed, bijector_seed = rng.integers(1e9, size=2)

        # if the bijector is None, set the default bijector
        if self._bijector_info is None:
            self._set_default_bijector(inputs, seed=bijector_seed)

        # validate epochs
        if not isinstance(epochs, int) or epochs <= 0:
            raise ValueError("epochs must be a positive integer.")

        # if no loss_fn is provided, use the default loss function
        if loss_fn is None:

            @jit
            def loss_fn(params, x, c):
                return -jnp.mean(self._log_prob(params, x, c))

        # initialize the optimizer
        optimizer = (
            optax.adam(learning_rate=1e-3) if optimizer is None else optimizer
        )
        opt_state = optimizer.init(self._params)

        # pull out the model parameters
        model_params = self._params

        # define the training step function
        @jit
        def step(params, opt_state, x, c):
            gradients = grad(loss_fn)(params, x, c)
            updates, opt_state = optimizer.update(gradients, opt_state, params)
            params = optax.apply_updates(params, updates)
            return params, opt_state

        # get list of data columns
        columns = list(self.data_columns)

        # if this is a conditional flow, and autoscale_conditions == True
        # save the means and stds of the conditional columns
        if self.conditional_columns is not None and self._autoscale_conditions:
            self._condition_means = jnp.array(
                inputs[list(self.conditional_columns)].to_numpy().mean(axis=0)
            )
            condition_stds = jnp.array(
                inputs[list(self.conditional_columns)].to_numpy().std(axis=0)
            )
            self._condition_stds = jnp.where(
                condition_stds != 0, condition_stds, 1
            )

        # define a function to return batches
        if convolve_errs:

            def get_batch(sample_key, x, type):
                return self._get_err_samples(sample_key, x, 1, type=type)

        else:

            def get_batch(sample_key, x, type):
                if type == "conditions":
                    return self._get_conditions(x)
                else:
                    return jnp.array(x[columns].to_numpy())

        # get random seed for training loop
        key = random.PRNGKey(batch_seed)

        if verbose:
            print(f"Training {epochs} epochs \nLoss:")

        # save the initial loss
        X = jnp.array(inputs[columns].to_numpy())
        C = self._get_conditions(inputs)
        losses = [loss_fn(model_params, X, C).item()]

        if val_set is not None:
            Xval = jnp.array(val_set[columns].to_numpy())
            Cval = self._get_conditions(val_set)
            val_losses = [loss_fn(model_params, Xval, Cval).item()]

        if verbose:
            if val_set is None:
                print(f"(0) {losses[-1]:.4f}")
            else:
                print(f"(0) {losses[-1]:.4f}  {val_losses[-1]:.4f}")

        # initialize variables for early stopping
        best_loss = jnp.inf
        best_param_vals = model_params
        early_stopping_counter = 0

        # loop through training
        loop = tqdm(range(epochs)) if progress_bar else range(epochs)
        for epoch in loop:
            # new permutation of batches
            permute_key, sample_key, key = random.split(key, num=3)
            idx = random.permutation(permute_key, inputs.shape[0])
            X = inputs.iloc[idx]

            # loop through batches and step optimizer
            for batch_idx in range(0, len(X), batch_size):
                # if sampling from the error distribution, this returns a
                # Gaussian sample of the batch. Else just returns batch as a
                # jax array
                batch = get_batch(
                    sample_key,
                    X.iloc[batch_idx : batch_idx + batch_size],
                    type="data",
                )
                batch_conditions = get_batch(
                    sample_key,
                    X.iloc[batch_idx : batch_idx + batch_size],
                    type="conditions",
                )

                model_params, opt_state = step(
                    model_params,
                    opt_state,
                    batch,
                    batch_conditions,
                )

            # save end-of-epoch training loss
            losses.append(
                loss_fn(
                    model_params,
                    jnp.array(X[columns].to_numpy()),
                    self._get_conditions(X),
                ).item()
            )

            # and validation loss
            if val_set is not None:
                val_losses.append(loss_fn(model_params, Xval, Cval).item())

            # if verbose, print current loss
            if verbose and (
                epoch % max(int(0.05 * epochs), 1) == 0
                or (epoch + 1) == epochs
            ):
                if val_set is None:
                    print(f"({epoch+1}) {losses[-1]:.4f}")
                else:
                    print(
                        f"({epoch+1}) {losses[-1]:.4f}  {val_losses[-1]:.4f}"
                    )

            # if patience was provided, or we are tracking the best params,
            # monitor the loss for early stopping / best-epoch selection
            if patience is not None or best_params:
                if val_set is None:
                    tracked_losses = losses
                else:
                    tracked_losses = val_losses

                # if loss didn't improve, increase counter
                # and check early stopping criterion
                if tracked_losses[-1] >= best_loss or jnp.isclose(
                    tracked_losses[-1], best_loss
                ):
                    early_stopping_counter += 1

                    # check if the early stopping criterion is met
                    if (
                        patience is not None
                        and early_stopping_counter >= patience
                    ):
                        print(
                            "Early stopping criterion is met.",
                            f"Training stopping after epoch {epoch+1}.",
                        )
                        break
                # if this is the best loss, reset the counter
                else:
                    best_loss = tracked_losses[-1]
                    best_param_vals = model_params
                    early_stopping_counter = 0

            # break if the training loss is NaN
            if not np.isfinite(losses[-1]):
                print(
                    f"Training stopping after epoch {epoch+1}",
                    "because training loss diverged.",
                )
                break

        # update the flow parameters with the final training state
        if best_params:
            self._params = best_param_vals
        else:
            self._params = model_params

        if val_set is None:
            return losses
        else:
            return [losses, val_losses]
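
The methods above are typically used together. The following is a minimal end-to-end sketch, not part of the library source; the toy DataFrame and the column names x and y are invented for illustration.

import numpy as np
import pandas as pd
from pzflow import Flow

# build a toy two-column dataset (purely illustrative)
rng = np.random.default_rng(0)
x = rng.normal(size=1000)
data = pd.DataFrame({"x": x, "y": x + 0.1 * rng.normal(size=1000)})

# no bijector is passed, so train() builds the default
# ShiftBounds -> RollingSplineCoupling bijector from the data
flow = Flow(data_columns=("x", "y"))
losses = flow.train(data, epochs=50, verbose=True)

samples = flow.sample(nsamples=10, seed=0)  # DataFrame with columns x, y
log_p = flow.log_prob(data)                 # one log density per row
flow.save("toy_flow.pkl")                   # reload with Flow(file="toy_flow.pkl")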

__init__(data_columns=None, bijector=None, latent=None, conditional_columns=None, data_error_model=None, condition_error_model=None, autoscale_conditions=True, seed=0, info=None, file=None, _dictionary=None)

Instantiate a normalizing flow.

Note that while all of the init parameters are technically optional, you must provide either data_columns OR file. In addition, if a file is provided, all other parameters must be None.

Parameters:

Name Type Description Default
data_columns Sequence[str]

Tuple, list, or other container of column names. These are the columns the flow expects/produces in DataFrames.

None
bijector Bijector Call

A Bijector call that consists of the bijector InitFunction that initializes the bijector and the tuple of Bijector Info. Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc. If not provided, the bijector can be set later using flow.set_bijector, or by calling flow.train, in which case the default bijector will be used. The default bijector is ShiftBounds -> RollingSplineCoupling, where the range of shift bounds is learned from the training data, and the dimensions of RollingSplineCoupling are inferred. The default bijector assumes that the latent has support [-5, 5] for every dimension.

None
latent distributions.LatentDist

The latent distribution for the normalizing flow. Can be any of the distributions from pzflow.distributions. If not provided, a uniform distribution is used with input_dim = len(data_columns), and B=5.

None
conditional_columns Sequence[str]

Names of columns on which to condition the normalizing flow.

None
data_error_model Callable

A callable that defines the error model for data variables. data_error_model must take key, X, Xerr, nsamples as arguments, where:
- key is a jax rng key, e.g. jax.random.PRNGKey(0)
- X is a 2D array of data variables, where the order of variables matches the order of the columns in data_columns
- Xerr is the corresponding 2D array of errors
- nsamples is the number of samples to draw from the error distribution
data_error_model must return an array of samples with the shape (X.shape[0], nsamples, X.shape[1]). If data_error_model is not provided, a Gaussian error model is assumed. A sketch of a custom error model appears after this parameter list.

None
condition_error_model Callable

A callable that defines the error model for conditional variables. condition_error_model must take key, X, Xerr, nsamples as arguments, where:
- key is a jax rng key, e.g. jax.random.PRNGKey(0)
- X is a 2D array of conditional variables, where the order of variables matches the order of columns in conditional_columns
- Xerr is the corresponding 2D array of errors
- nsamples is the number of samples to draw from the error distribution
condition_error_model must return an array of samples with shape (X.shape[0], nsamples, X.shape[1]). If condition_error_model is not provided, a Gaussian error model is assumed.

None
autoscale_conditions bool

Sets whether or not conditions are automatically standard scaled when passed to a conditional flow. I recommend you leave as True.

True
seed int

The random seed for initial parameters

0
info Any

An object to attach to the info attribute.

None
file str

Path to file from which to load a pretrained flow. If a file is provided, all other parameters must be None.

None
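As a minimal sketch of a conditional flow (the column names, bounds, and bijector settings below are illustrative choices, not defaults):

import jax.numpy as jnp
from pzflow import Flow
from pzflow.bijectors import Chain, ShiftBounds, RollingSplineCoupling

# one data column ("redshift"), conditioned on two other columns ("u", "g");
# ShiftBounds maps the data into [-4, 4], inside the default latent support of [-5, 5]
flow = Flow(
    data_columns=("redshift",),
    conditional_columns=("u", "g"),
    bijector=Chain(
        ShiftBounds(jnp.array([0.0]), jnp.array([3.0]), 4.0),
        RollingSplineCoupling(1, n_conditions=2),
    ),
    seed=0,
)

And a sketch of a custom error model satisfying the calling convention described for data_error_model above (the name uniform_error_model is hypothetical):

import jax.numpy as jnp
from jax import random
from pzflow import Flow

def uniform_error_model(key, X, Xerr, nsamples):
    # draw nsamples uniform samples within +/- Xerr of each value of X;
    # returns shape (X.shape[0], nsamples, X.shape[1]) as required
    u = random.uniform(
        key, shape=(X.shape[0], nsamples, X.shape[1]), minval=-1.0, maxval=1.0
    )
    return X[:, None, :] + u * Xerr[:, None, :]

flow = Flow(data_columns=("x", "y"), data_error_model=uniform_error_model)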
Source code in pzflow/flow.py
def __init__(
    self,
    data_columns: Sequence[str] = None,
    bijector: Tuple[InitFunction, Bijector_Info] = None,
    latent: distributions.LatentDist = None,
    conditional_columns: Sequence[str] = None,
    data_error_model: Callable = None,
    condition_error_model: Callable = None,
    autoscale_conditions: bool = True,
    seed: int = 0,
    info: Any = None,
    file: str = None,
    _dictionary: dict = None,
) -> None:
    """Instantiate a normalizing flow.

    Note that while all of the init parameters are technically optional,
    you must provide either data_columns OR file.
    In addition, if a file is provided, all other parameters must be None.

    Parameters
    ----------
    data_columns : Sequence[str]; optional
        Tuple, list, or other container of column names.
        These are the columns the flow expects/produces in DataFrames.
    bijector : Bijector Call; optional
        A Bijector call that consists of the bijector InitFunction that
        initializes the bijector and the tuple of Bijector Info.
        Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
        If not provided, the bijector can be set later using
        flow.set_bijector, or by calling flow.train, in which case the
        default bijector will be used. The default bijector is
        ShiftBounds -> RollingSplineCoupling, where the range of shift
        bounds is learned from the training data, and the dimensions of
        RollingSplineCoupling are inferred. The default bijector assumes
        that the latent has support [-5, 5] for every dimension.
    latent : distributions.LatentDist; optional
        The latent distribution for the normalizing flow. Can be any of
        the distributions from pzflow.distributions. If not provided,
        a uniform distribution is used with input_dim = len(data_columns),
        and B=5.
    conditional_columns : Sequence[str]; optional
        Names of columns on which to condition the normalizing flow.
    data_error_model : Callable; optional
        A callable that defines the error model for data variables.
        data_error_model must take key, X, Xerr, nsamples as arguments:
            - key is a jax rng key, e.g. jax.random.PRNGKey(0)
            - X is 2D array of data variables, where the order of variables
                matches the order of the columns in data_columns
            - Xerr is the corresponding 2D array of errors
            - nsamples is number of samples to draw from error distribution
        data_error_model must return an array of samples with the shape
        (X.shape[0], nsamples, X.shape[1]).
        If data_error_model is not provided, Gaussian error model assumed.
    condition_error_model : Callable; optional
        A callable that defines the error model for conditional variables.
        condition_error_model must take key, X, Xerr, nsamples, where:
            - key is a jax rng key, e.g. jax.random.PRNGKey(0)
            - X is 2D array of conditional variables, where the order of
                variables matches order of columns in conditional_columns
            - Xerr is the corresponding 2D array of errors
            - nsamples is number of samples to draw from error distribution
        condition_error_model must return array of samples with shape
        (X.shape[0], nsamples, X.shape[1]).
        If condition_error_model is not provided, Gaussian error model
        assumed.
    autoscale_conditions : bool; default=True
        Sets whether or not conditions are automatically standard scaled
        when passed to a conditional flow. I recommend you leave as True.
    seed : int; default=0
        The random seed for initial parameters
    info : Any; optional
        An object to attach to the info attribute.
    file : str; optional
        Path to file from which to load a pretrained flow.
        If a file is provided, all other parameters must be None.
    """

    # validate parameters
    if data_columns is None and file is None and _dictionary is None:
        raise ValueError("You must provide data_columns OR file.")
    if any(
        (
            data_columns is not None,
            bijector is not None,
            conditional_columns is not None,
            latent is not None,
            data_error_model is not None,
            condition_error_model is not None,
            info is not None,
        )
    ):
        if file is not None:
            raise ValueError(
                "If providing a file, please do not provide any other parameters."
            )
        if _dictionary is not None:
            raise ValueError(
                "If providing a dictionary, please do not provide any other parameters."
            )
    if file is not None and _dictionary is not None:
        raise ValueError("Only provide file or _dictionary, not both.")

    # if file or dictionary is provided, load everything from it
    if file is not None or _dictionary is not None:
        save_dict = self._save_dict()
        if file is not None:
            with open(file, "rb") as handle:
                save_dict.update(pickle.load(handle))
        else:
            save_dict.update(_dictionary)

        if save_dict["class"] != self.__class__.__name__:
            raise TypeError(
                f"This save file isn't a {self.__class__.__name__}. "
                f"It is a {save_dict['class']}"
            )

        # load columns and dimensions
        self.data_columns = save_dict["data_columns"]
        self.conditional_columns = save_dict["conditional_columns"]
        self._input_dim = len(self.data_columns)
        self.info = save_dict["info"]

        # load the latent distribution
        self._latent_info = save_dict["latent_info"]
        self.latent = getattr(distributions, self._latent_info[0])(
            *self._latent_info[1]
        )

        # load the error models
        self.data_error_model = save_dict["data_error_model"]
        self.condition_error_model = save_dict["condition_error_model"]

        # load the bijector
        self._bijector_info = save_dict["bijector_info"]
        if self._bijector_info is not None:
            init_fun, _ = build_bijector_from_info(self._bijector_info)
            _, self._forward, self._inverse = init_fun(
                random.PRNGKey(0), self._input_dim
            )
        self._params = save_dict["params"]

        # load the conditional means and stds
        self._condition_means = save_dict["condition_means"]
        self._condition_stds = save_dict["condition_stds"]

        # set whether or not to automatically standard scale any
        # conditions passed to the normalizing flow
        self._autoscale_conditions = save_dict["autoscale_conditions"]

    # if no file is provided, use provided parameters
    else:
        self.data_columns = tuple(data_columns)
        self._input_dim = len(self.data_columns)
        self.info = info

        if conditional_columns is None:
            self.conditional_columns = None
            self._condition_means = None
            self._condition_stds = None
        else:
            self.conditional_columns = tuple(conditional_columns)
            self._condition_means = jnp.zeros(
                len(self.conditional_columns)
            )
            self._condition_stds = jnp.ones(len(self.conditional_columns))

        # set whether or not to automatically standard scale any
        # conditions passed to the normalizing flow
        self._autoscale_conditions = autoscale_conditions

        # set up the latent distribution
        if latent is None:
            self.latent = distributions.Uniform(self._input_dim, 5)
        else:
            self.latent = latent
        self._latent_info = self.latent.info

        # make sure the latent distribution and data_columns have the
        # same number of dimensions
        if self.latent.input_dim != len(data_columns):
            raise ValueError(
                f"The latent distribution has {self.latent.input_dim} "
                f"dimensions, but data_columns has {len(data_columns)} "
                "dimensions. They must match!"
            )

        # set up the error models
        if data_error_model is None:
            self.data_error_model = gaussian_error_model
        else:
            self.data_error_model = data_error_model
        if condition_error_model is None:
            self.condition_error_model = gaussian_error_model
        else:
            self.condition_error_model = condition_error_model

        # set up the bijector
        if bijector is not None:
            self.set_bijector(bijector, seed=seed)
        # if no bijector was provided, set bijector_info to None
        else:
            self._bijector_info = None

log_prob(inputs, err_samples=None, seed=None)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Input data for which log probability density is calculated. Every column in self.data_columns must be present. If self.conditional_columns is not None, those must be present as well. If other columns are present, they are ignored.

required
err_samples int

Number of samples from the error distribution to average over for the log_prob calculation. If provided, Gaussian errors are assumed, and method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns.

None
seed int

Random seed for drawing the samples with Gaussian errors.

None

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0],).

Source code in pzflow/flow.py
def log_prob(
    self, inputs: pd.DataFrame, err_samples: int = None, seed: int = None
) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Input data for which log probability density is calculated.
        Every column in self.data_columns must be present.
        If self.conditional_columns is not None, those must be present
        as well. If other columns are present, they are ignored.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the log_prob calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0],).
    """

    # check that the bijector exists
    self._check_bijector()

    if err_samples is None:
        # convert data to an array with columns ordered
        columns = list(self.data_columns)
        X = jnp.array(inputs[columns].to_numpy())
        # get conditions
        conditions = self._get_conditions(inputs)
        # calculate log_prob
        return self._log_prob(self._params, X, conditions)

    else:
        # validate nsamples
        assert isinstance(
            err_samples, int
        ), "err_samples must be a positive integer."
        assert err_samples > 0, "err_samples must be a positive integer."
        # get Gaussian samples
        seed = np.random.randint(1e18) if seed is None else seed
        key = random.PRNGKey(seed)
        X = self._get_err_samples(key, inputs, err_samples, type="data")
        C = self._get_err_samples(
            key, inputs, err_samples, type="conditions"
        )
        # calculate log_probs
        log_probs = self._log_prob(self._params, X, C)
        probs = jnp.exp(log_probs.reshape(-1, err_samples))
        return jnp.log(probs.mean(axis=1))

posterior(inputs, column, grid, marg_rules=None, normalize=True, err_samples=None, seed=None, batch_size=None, nan_to_zero=True)

Calculates posterior distributions for the provided column.

Calculates the conditional posterior distribution, assuming the data values in the other columns of the DataFrame.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Data on which the posterior distributions are conditioned. Must have columns matching self.data_columns, except for the column specified for the posterior (see below).

required
column str

Name of the column for which the posterior distribution is calculated. Must be one of the columns in self.data_columns. However, whether or not this column is one of the columns in inputs is irrelevant.

required
grid jnp.ndarray

Grid on which to calculate the posterior.

required
marg_rules dict

Dictionary with rules for marginalizing over missing variables. The dictionary must contain the key "flag", which gives the flag that indicates a missing value. E.g. if missing values are given the value 99, the dictionary should contain {"flag": 99}. The dictionary must also contain {"name": callable} for any variables that will need to be marginalized over, where name is the name of the variable, and callable is a callable that takes the row of variables and returns a grid over which to marginalize the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}. Note: the callable for a given name must always return an array of the same length, regardless of the input row.

None
err_samples int

Number of samples from the error distribution to average over for the posterior calculation. If provided, Gaussian errors are assumed, and method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns.

None
seed int

Random seed for drawing the samples with Gaussian errors.

None
batch_size int

Size of batches in which to calculate posteriors. If None, all posteriors are calculated simultaneously. Simultaneous calculation is faster, but memory intensive for large data sets.

None
normalize boolean

Whether to normalize the posterior so that it integrates to 1.

True
nan_to_zero bool

Whether to convert NaN's to zero probability in the final pdfs.

True

Returns:

Type Description
jnp.ndarray

Device array of shape (inputs.shape[0], grid.size).

Source code in pzflow/flow.py
def posterior(
    self,
    inputs: pd.DataFrame,
    column: str,
    grid: jnp.ndarray,
    marg_rules: dict = None,
    normalize: bool = True,
    err_samples: int = None,
    seed: int = None,
    batch_size: int = None,
    nan_to_zero: bool = True,
) -> jnp.ndarray:
    """Calculates posterior distributions for the provided column.

    Calculates the conditional posterior distribution, assuming the
    data values in the other columns of the DataFrame.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which the posterior distributions are conditioned.
        Must have columns matching self.data_columns, *except*
        for the column specified for the posterior (see below).
    column : str
        Name of the column for which the posterior distribution
        is calculated. Must be one of the columns in self.data_columns.
        However, whether or not this column is one of the columns in
        `inputs` is irrelevant.
    grid : jnp.ndarray
        Grid on which to calculate the posterior.
    marg_rules : dict; optional
        Dictionary with rules for marginalizing over missing variables.
        The dictionary must contain the key "flag", which gives the flag
        that indicates a missing value. E.g. if missing values are given
        the value 99, the dictionary should contain {"flag": 99}.
        The dictionary must also contain {"name": callable} for any
        variables that will need to be marginalized over, where name is
        the name of the variable, and callable is a callable that takes
        the row of variables and returns a grid over which to marginalize
        the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
        Note: the callable for a given name must *always* return an array
        of the same length, regardless of the input row.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the posterior calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    batch_size : int; default=None
        Size of batches in which to calculate posteriors. If None, all
        posteriors are calculated simultaneously. Simultaneous calculation
        is faster, but memory intensive for large data sets.
    normalize : boolean; default=True
        Whether to normalize the posterior so that it integrates to 1.
    nan_to_zero : bool; default=True
        Whether to convert NaNs to zero probability in the final pdfs.

    Returns
    -------
    jnp.ndarray
        Device array of shape (inputs.shape[0], grid.size).
    """

    # check that the bijector exists
    self._check_bijector()

    # get the index of the provided column, and remove it from the list
    columns = list(self.data_columns)
    idx = columns.index(column)
    columns.remove(column)

    nrows = inputs.shape[0]
    batch_size = nrows if batch_size is None else batch_size

    # make sure indices run 0 -> nrows
    inputs = inputs.reset_index(drop=True)

    if err_samples is not None:
        # validate nsamples
        assert isinstance(
            err_samples, int
        ), "err_samples must be a positive integer."
        assert err_samples > 0, "err_samples must be a positive integer."
        # set the seed
        seed = np.random.randint(1e18) if seed is None else seed
        key = random.PRNGKey(seed)

    # empty array to hold pdfs
    pdfs = jnp.zeros((nrows, len(grid)))

    # if marginalization rules were passed, we will loop over the rules
    # and repeatedly call this method
    if marg_rules is not None:
        # if the flag is NaN, we must use jnp.isnan to check for flags
        if np.isnan(marg_rules["flag"]):

            def check_flags(data):
                return np.isnan(data)

        # else we use jnp.isclose to check for flags
        else:

            def check_flags(data):
                return np.isclose(data, marg_rules["flag"])

        # first calculate pdfs for unflagged rows
        unflagged_idx = inputs[
            ~check_flags(inputs[columns]).any(axis=1)
        ].index.tolist()
        unflagged_pdfs = self.posterior(
            inputs=inputs.iloc[unflagged_idx],
            column=column,
            grid=grid,
            err_samples=err_samples,
            seed=seed,
            batch_size=batch_size,
            normalize=False,
            nan_to_zero=nan_to_zero,
        )

        # save these pdfs in the big array
        pdfs = pdfs.at[unflagged_idx, :].set(
            unflagged_pdfs,
            indices_are_sorted=True,
            unique_indices=True,
        )

        # we will keep track of all the rows we've already calculated
        # posteriors for
        already_done = unflagged_idx

        # now we will loop over the rules in marg_rules
        for name, rule in marg_rules.items():
            # ignore the flag, because that's not a column in the data
            if name == "flag":
                continue

            # get the list of new rows for which we need to calculate posteriors
            flagged_idx = inputs[check_flags(inputs[name])].index.tolist()
            flagged_idx = list(set(flagged_idx).difference(already_done))

            # if flagged_idx is empty, move on!
            if len(flagged_idx) == 0:
                continue

            # get the marginalization grid for each row
            marg_grids = (
                inputs.iloc[flagged_idx]
                .apply(rule, axis=1, result_type="expand")
                .to_numpy()
            )

            # make a new data frame with the marginalization grids replacing
            # the values of the flag in the column
            marg_inputs = pd.DataFrame(
                np.repeat(
                    inputs.iloc[flagged_idx].to_numpy(),
                    marg_grids.shape[1],
                    axis=0,
                ),
                columns=inputs.columns,
            )
            marg_inputs[name] = marg_grids.reshape(marg_inputs.shape[0], 1)

            # remove the error column if it's present
            marg_inputs.drop(
                f"{name}_err", axis=1, inplace=True, errors="ignore"
            )

            # calculate posteriors for these
            marg_pdfs = self.posterior(
                inputs=marg_inputs,
                column=column,
                grid=grid,
                marg_rules=marg_rules,
                err_samples=err_samples,
                seed=seed,
                batch_size=batch_size,
                normalize=False,
                nan_to_zero=nan_to_zero,
            )

            # sum over the marginalized dimension
            marg_pdfs = marg_pdfs.reshape(
                len(flagged_idx), marg_grids.shape[1], grid.size
            )
            marg_pdfs = marg_pdfs.sum(axis=1)

            # save the new pdfs in the big array
            pdfs = pdfs.at[flagged_idx, :].set(
                marg_pdfs,
                indices_are_sorted=True,
                unique_indices=True,
            )

            # add these flagged indices to the list of rows already done
            already_done += flagged_idx

    # now for the main posterior calculation loop
    else:
        # loop through batches
        for batch_idx in range(0, nrows, batch_size):
            # get the data batch
            # and, if this is a conditional flow, the corresponding conditions
            batch = inputs.iloc[batch_idx : batch_idx + batch_size]

            # if not drawing samples, just grab batch and conditions
            if err_samples is None:
                conditions = self._get_conditions(batch)
                batch = jnp.array(batch[columns].to_numpy())
            # if only drawing condition samples...
            elif len(self.data_columns) == 1:
                conditions = self._get_err_samples(
                    key, batch, err_samples, type="conditions"
                )
                batch = jnp.repeat(
                    batch[columns].to_numpy(), err_samples, axis=0
                )
            # if drawing data and condition samples...
            else:
                conditions = self._get_err_samples(
                    key, batch, err_samples, type="conditions"
                )
                batch = self._get_err_samples(
                    key, batch, err_samples, skip=column, type="data"
                )

            # make a new copy of each row for each value of the column
            # for which we are calculating the posterior
            batch = jnp.hstack(
                (
                    jnp.repeat(
                        batch[:, :idx],
                        len(grid),
                        axis=0,
                    ),
                    jnp.tile(grid, len(batch))[:, None],
                    jnp.repeat(
                        batch[:, idx:],
                        len(grid),
                        axis=0,
                    ),
                )
            )

            # make similar copies of the conditions
            conditions = jnp.repeat(conditions, len(grid), axis=0)

            # calculate probability densities
            log_prob = self._log_prob(
                self._params, batch, conditions
            ).reshape((-1, len(grid)))
            prob = jnp.exp(log_prob)
            # if we were Gaussian sampling, average over the samples
            if err_samples is not None:
                prob = prob.reshape(-1, err_samples, len(grid))
                prob = prob.mean(axis=1)
            # add the pdfs to the bigger list
            pdfs = pdfs.at[batch_idx : batch_idx + batch_size, :].set(
                prob,
                indices_are_sorted=True,
                unique_indices=True,
            )

    if normalize:
        # normalize so they integrate to one
        pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
    if nan_to_zero:
        # set NaNs equal to zero probability
        pdfs = jnp.nan_to_num(pdfs, nan=0.0)
    return pdfs
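
For context, here is a minimal, hypothetical usage sketch. It assumes a flow that has already been trained with data_columns = ("redshift", "u", "g"); the column names, magnitudes, flag value, and grid below are illustrative only.

    import jax.numpy as jnp
    import pandas as pd

    # photometry to condition on; the "redshift" column itself is not needed
    data = pd.DataFrame({"u": [24.3, 25.1], "g": [23.9, 99.0]})
    grid = jnp.linspace(0, 3, 100)

    # posterior over redshift for each row of `data`
    pdfs = flow.posterior(data, column="redshift", grid=grid)
    print(pdfs.shape)  # (2, 100)

    # marginalize over "g" wherever it carries the missing-value flag 99;
    # the callable must always return a grid of the same length
    marg_rules = {"flag": 99, "g": lambda row: jnp.linspace(22, 28, 10)}
    pdfs = flow.posterior(data, column="redshift", grid=grid, marg_rules=marg_rules)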

sample(nsamples=1, conditions=None, save_conditions=True, seed=None)

Returns samples from the normalizing flow.

Parameters:

Name Type Description Default
nsamples int

The number of samples to be returned.

1
conditions pd.DataFrame

If this is a conditional flow, you must pass conditions for each sample. nsamples will be drawn for each row in conditions.

None
save_conditions bool

If true, conditions will be saved in the DataFrame of samples that is returned.

True
seed int

Sets the random seed for the samples.

None

Returns:

Type Description
pd.DataFrame

Pandas DataFrame of samples.

Source code in pzflow/flow.py
def sample(
    self,
    nsamples: int = 1,
    conditions: pd.DataFrame = None,
    save_conditions: bool = True,
    seed: int = None,
) -> pd.DataFrame:
    """Returns samples from the normalizing flow.

    Parameters
    ----------
    nsamples : int; default=1
        The number of samples to be returned.
    conditions : pd.DataFrame; optional
        If this is a conditional flow, you must pass conditions for
        each sample. nsamples will be drawn for each row in conditions.
    save_conditions : bool; default=True
        If true, conditions will be saved in the DataFrame of samples
        that is returned.
    seed : int; optional
        Sets the random seed for the samples.

    Returns
    -------
    pd.DataFrame
        Pandas DataFrame of samples.
    """

    # check that the bijector exists
    self._check_bijector()

    # validate nsamples
    assert isinstance(
        nsamples, int
    ), "nsamples must be a positive integer."
    assert nsamples > 0, "nsamples must be a positive integer."

    if self.conditional_columns is not None and conditions is None:
        raise ValueError(
            f"Must provide the following conditions\n{self.conditional_columns}"
        )

    # if this isn't a conditional flow, get empty conditions
    if self.conditional_columns is None:
        conditions = jnp.zeros((nsamples, 1))
    # otherwise get conditions and make `nsamples` copies of each
    else:
        conditions_idx = list(conditions.index)
        conditions = self._get_conditions(conditions)
        conditions_idx = np.repeat(conditions_idx, nsamples)
        conditions = jnp.repeat(conditions, nsamples, axis=0)

    # draw from latent distribution
    u = self.latent.sample(self._params[0], conditions.shape[0], seed)
    # take the inverse back to the data distribution
    x = self._inverse(self._params[1], u, conditions=conditions)[0]
    # if not conditional, this is all we need
    if self.conditional_columns is None:
        x = pd.DataFrame(np.array(x), columns=self.data_columns)
    # but if conditional
    else:
        if save_conditions:
            # unscale the conditions
            conditions = (
                conditions * self._condition_stds + self._condition_means
            )
            x = pd.DataFrame(
                np.array(jnp.hstack((x, conditions))),
                columns=self.data_columns + self.conditional_columns,
            ).set_index(conditions_idx)
        else:
            # reindex according to the conditions
            x = pd.DataFrame(
                np.array(x), columns=self.data_columns
            ).set_index(conditions_idx)

    # return the samples!
    return x
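
A brief, hypothetical usage sketch: `flow` is assumed to be an already-trained unconditional flow, and `cond_flow` an already-trained flow with conditional_columns = ("u", "g"); the numbers are placeholders.

    import pandas as pd

    # unconditional flow: draw 1000 samples as a DataFrame with self.data_columns
    samples = flow.sample(nsamples=1000, seed=0)

    # conditional flow: nsamples rows are drawn for every row of `conditions`
    conditions = pd.DataFrame({"u": [24.3, 25.1], "g": [23.9, 24.6]})
    cond_samples = cond_flow.sample(nsamples=10, conditions=conditions, seed=0)
    # 2 condition rows x 10 samples each = 20 rows, indexed by the condition rows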

save(file)

Saves the flow to a file.

Pickles the flow and saves it to a file that can be passed as the file argument during flow instantiation.

WARNING: Currently, this method only works for bijectors that are implemented in the bijectors module. If you want to save a flow with a custom bijector, you either need to add the bijector to that module, or handle the saving and loading on your end.

Parameters:

Name Type Description Default
file str

Path to where the flow will be saved. Extension .pkl will be appended if not already present.

required
Source code in pzflow/flow.py
def save(self, file: str) -> None:
    """Saves the flow to a file.

    Pickles the flow and saves it to a file that can be passed as
    the `file` argument during flow instantiation.

    WARNING: Currently, this method only works for bijectors that are
    implemented in the `bijectors` module. If you want to save a flow
    with a custom bijector, you either need to add the bijector to that
    module, or handle the saving and loading on your end.

    Parameters
    ----------
    file : str
        Path to where the flow will be saved.
        Extension `.pkl` will be appended if not already present.
    """
    save_dict = self._save_dict()

    with open(file, "wb") as handle:
        pickle.dump(save_dict, handle, recurse=True)
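
A short sketch of the save/load round trip (the file name is arbitrary):

    from pzflow import Flow

    flow.save("trained_flow.pkl")

    # reload later by passing the file to the constructor
    flow = Flow(file="trained_flow.pkl")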

set_bijector(bijector, params=None, seed=0)

Set the bijector.

Parameters:

Name Type Description Default
bijector Bijector Call

A Bijector call that consists of the bijector InitFunction that initializes the bijector and the tuple of Bijector Info. Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.

required
params Pytree

A Pytree of bijector parameters. If not provided, the bijector will be initialized with random parameters.

None
seed int

A random seed for initializing the bijector with random parameters.

0
Source code in pzflow/flow.py
def set_bijector(
    self,
    bijector: Tuple[InitFunction, Bijector_Info],
    params: Pytree = None,
    seed: int = 0,
) -> None:
    """Set the bijector.

    Parameters
    ----------
    bijector : Bijector Call
        A Bijector call that consists of the bijector InitFunction that
        initializes the bijector and the tuple of Bijector Info.
        Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
    params : Pytree; optional
        A Pytree of bijector parameters. If not provided, the bijector
        will be initialized with random parameters.
    seed: int; default=0
        A random seed for initializing the bijector with random parameters.
    """

    # set up the bijector
    init_fun, self._bijector_info = bijector
    bijector_params, self._forward, self._inverse = init_fun(
        random.PRNGKey(seed), self._input_dim
    )

    # check if params were passed
    bijector_params = params if params is not None else bijector_params

    # save the bijector params along with the latent params
    self._params = (self.latent._params, bijector_params)
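
A minimal sketch using the Reverse bijector mentioned above; any other Bijector output (e.g. a Chain(...)) is passed the same way, and the column names here are placeholders.

    from pzflow import Flow
    from pzflow.bijectors import Reverse

    flow = Flow(("x", "y"))
    flow.set_bijector(Reverse())  # no params provided, so the bijector is initialized with seed=0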

train(inputs, val_set=None, epochs=100, batch_size=1024, optimizer=None, loss_fn=None, convolve_errs=False, patience=None, best_params=True, seed=0, verbose=False, progress_bar=False)

Trains the normalizing flow on the provided inputs.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Data on which to train the normalizing flow. Must have columns matching self.data_columns.

required
val_set pd.DataFrame

Validation set, of same format as inputs. If provided, validation loss will be calculated at the end of each epoch.

None
epochs int

Number of epochs to train.

100
batch_size int

Batch size for training.

1024
optimizer optax optimizer

An optimizer from Optax. Default is optax.adam(learning_rate=1e-3); see https://optax.readthedocs.io/en/latest/index.html for more.

None
loss_fn Callable

A function to calculate the loss: loss = loss_fn(params, x). If not provided, will be -mean(log_prob).

None
convolve_errs bool

Whether to draw new data from the error distributions during each epoch of training. Method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns. The error distribution is set during flow instantiation.

False
patience int

Factor that controls early stopping. Training will stop if the loss doesn't decrease for this number of epochs. Note if a validation set is provided, the validation loss is used.

None
best_params bool

Whether to use the params from the epoch with the lowest loss. Note if a validation set is provided, the epoch with the lowest validation loss is chosen. If False, the params from the final epoch are saved.

True
seed int

A random seed to control the batching, the (optional) error sampling, and the creation of the default bijector (the latter only happens if you didn't set up the bijector during Flow instantiation).

0
verbose bool

If true, print the training loss every 5% of epochs.

False
progress_bar bool

If true, display a tqdm progress bar during training.

False

Returns:

Type Description
list

List of training losses from every epoch. If no val_set provided, these are just training losses. If val_set is provided, then the first element is the list of training losses, while the second is the list of validation losses.

Source code in pzflow/flow.py
def train(
    self,
    inputs: pd.DataFrame,
    val_set: pd.DataFrame = None,
    epochs: int = 100,
    batch_size: int = 1024,
    optimizer: Callable = None,
    loss_fn: Callable = None,
    convolve_errs: bool = False,
    patience: int = None,
    best_params: bool = True,
    seed: int = 0,
    verbose: bool = False,
    progress_bar: bool = False,
) -> list:
    """Trains the normalizing flow on the provided inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which to train the normalizing flow.
        Must have columns matching `self.data_columns`.
    val_set : pd.DataFrame; default=None
        Validation set, of same format as inputs. If provided,
        validation loss will be calculated at the end of each epoch.
    epochs : int; default=100
        Number of epochs to train.
    batch_size : int; default=1024
        Batch size for training.
    optimizer : optax optimizer
        An optimizer from Optax. default = optax.adam(learning_rate=1e-3)
        see https://optax.readthedocs.io/en/latest/index.html for more.
    loss_fn : Callable; optional
        A function to calculate the loss: `loss = loss_fn(params, x)`.
        If not provided, will be `-mean(log_prob)`.
    convolve_errs : bool; default=False
        Whether to draw new data from the error distributions during
        each epoch of training. Method will look for error columns in
        `inputs`. Error columns must end in `_err`. E.g. the error column
        for the variable `u` must be `u_err`. Zero error assumed for
        any missing error columns. The error distribution is set during
        flow instantiation.
    patience : int; optional
        Factor that controls early stopping. Training will stop if the
        loss doesn't decrease for this number of epochs. Note if a
        validation set is provided, the validation loss is used.
    best_params : bool; default=True
        Whether to use the params from the epoch with the lowest loss.
        Note if a validation set is provided, the epoch with the lowest
        validation loss is chosen. If False, the params from the final
        epoch are saved.
    seed : int; default=0
        A random seed to control the batching and the (optional)
        error sampling and creating the default bijector (the latter
        only happens if you didn't set up the bijector during Flow
        instantiation).
    verbose : bool; default=False
        If true, print the training loss every 5% of epochs.
    progress_bar : bool; default=False
        If true, display a tqdm progress bar during training.

    Returns
    -------
    list
        List of training losses from every epoch. If no val_set provided,
        these are just training losses. If val_set is provided, then the
        first element is the list of training losses, while the second is
        the list of validation losses.
    """

    # split the seed
    rng = np.random.default_rng(seed)
    batch_seed, bijector_seed = rng.integers(1e9, size=2)

    # if the bijector is None, set the default bijector
    if self._bijector_info is None:
        self._set_default_bijector(inputs, seed=bijector_seed)

    # validate epochs
    if not isinstance(epochs, int) or epochs <= 0:
        raise ValueError("epochs must be a positive integer.")

    # if no loss_fn is provided, use the default loss function
    if loss_fn is None:

        @jit
        def loss_fn(params, x, c):
            return -jnp.mean(self._log_prob(params, x, c))

    # initialize the optimizer
    optimizer = (
        optax.adam(learning_rate=1e-3) if optimizer is None else optimizer
    )
    opt_state = optimizer.init(self._params)

    # pull out the model parameters
    model_params = self._params

    # define the training step function
    @jit
    def step(params, opt_state, x, c):
        gradients = grad(loss_fn)(params, x, c)
        updates, opt_state = optimizer.update(gradients, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    # get list of data columns
    columns = list(self.data_columns)

    # if this is a conditional flow, and autoscale_conditions == True
    # save the means and stds of the conditional columns
    if self.conditional_columns is not None and self._autoscale_conditions:
        self._condition_means = jnp.array(
            inputs[list(self.conditional_columns)].to_numpy().mean(axis=0)
        )
        condition_stds = jnp.array(
            inputs[list(self.conditional_columns)].to_numpy().std(axis=0)
        )
        self._condition_stds = jnp.where(
            condition_stds != 0, condition_stds, 1
        )

    # define a function to return batches
    if convolve_errs:

        def get_batch(sample_key, x, type):
            return self._get_err_samples(sample_key, x, 1, type=type)

    else:

        def get_batch(sample_key, x, type):
            if type == "conditions":
                return self._get_conditions(x)
            else:
                return jnp.array(x[columns].to_numpy())

    # get random seed for training loop
    key = random.PRNGKey(batch_seed)

    if verbose:
        print(f"Training {epochs} epochs \nLoss:")

    # save the initial loss
    X = jnp.array(inputs[columns].to_numpy())
    C = self._get_conditions(inputs)
    losses = [loss_fn(model_params, X, C).item()]

    if val_set is not None:
        Xval = jnp.array(val_set[columns].to_numpy())
        Cval = self._get_conditions(val_set)
        val_losses = [loss_fn(model_params, Xval, Cval).item()]

    if verbose:
        if val_set is None:
            print(f"(0) {losses[-1]:.4f}")
        else:
            print(f"(0) {losses[-1]:.4f}  {val_losses[-1]:.4f}")

    # initialize variables for early stopping
    best_loss = jnp.inf
    best_param_vals = model_params
    early_stopping_counter = 0

    # loop through training
    loop = tqdm(range(epochs)) if progress_bar else range(epochs)
    for epoch in loop:
        # new permutation of batches
        permute_key, sample_key, key = random.split(key, num=3)
        idx = random.permutation(permute_key, inputs.shape[0])
        X = inputs.iloc[idx]

        # loop through batches and step optimizer
        for batch_idx in range(0, len(X), batch_size):
            # if sampling from the error distribution, this returns a
            # Gaussian sample of the batch. Else just returns batch as a
            # jax array
            batch = get_batch(
                sample_key,
                X.iloc[batch_idx : batch_idx + batch_size],
                type="data",
            )
            batch_conditions = get_batch(
                sample_key,
                X.iloc[batch_idx : batch_idx + batch_size],
                type="conditions",
            )

            model_params, opt_state = step(
                model_params,
                opt_state,
                batch,
                batch_conditions,
            )

        # save end-of-epoch training loss
        losses.append(
            loss_fn(
                model_params,
                jnp.array(X[columns].to_numpy()),
                self._get_conditions(X),
            ).item()
        )

        # and validation loss
        if val_set is not None:
            val_losses.append(loss_fn(model_params, Xval, Cval).item())

        # if verbose, print current loss
        if verbose and (
            epoch % max(int(0.05 * epochs), 1) == 0
            or (epoch + 1) == epochs
        ):
            if val_set is None:
                print(f"({epoch+1}) {losses[-1]:.4f}")
            else:
                print(
                    f"({epoch+1}) {losses[-1]:.4f}  {val_losses[-1]:.4f}"
                )

        # if patience provided, we need to check for early stopping
        if patience is not None or best_loss:
            if val_set is None:
                tracked_losses = losses
            else:
                tracked_losses = val_losses

            # if loss didn't improve, increase counter
            # and check early stopping criterion
            if tracked_losses[-1] >= best_loss or jnp.isclose(
                tracked_losses[-1], best_loss
            ):
                early_stopping_counter += 1

                # check if the early stopping criterion is met
                if (
                    patience is not None
                    and early_stopping_counter >= patience
                ):
                    print(
                        "Early stopping criterion is met.",
                        f"Training stopping after epoch {epoch+1}.",
                    )
                    break
            # if this is the best loss, reset the counter
            else:
                best_loss = tracked_losses[-1]
                best_param_vals = model_params
                early_stopping_counter = 0

        # break if the training loss is NaN
        if not np.isfinite(losses[-1]):
            print(
                f"Training stopping after epoch {epoch+1}",
                "because training loss diverged.",
            )
            break

    # update the flow parameters with the final training state
    if best_params:
        self._params = best_param_vals
    else:
        self._params = model_params

    if val_set is None:
        return losses
    else:
        return [losses, val_losses]
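
A hypothetical training sketch: train_df and val_df are placeholder DataFrames whose columns match flow.data_columns (plus optional *_err columns), and the hyperparameters are illustrative.

    losses = flow.train(
        inputs=train_df,
        val_set=val_df,
        epochs=200,
        batch_size=1024,
        patience=20,     # stop if the validation loss stalls for 20 epochs
        verbose=True,
    )
    train_losses, val_losses = losses  # two lists are returned when val_set is given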

flowEnsemble

Define FlowEnsemble object that holds an ensemble of normalizing flows.

FlowEnsemble

An ensemble of normalizing flows.

Attributes:

Name Type Description
data_columns tuple

List of DataFrame columns that the flows expect/produce.

conditional_columns tuple

List of DataFrame columns on which the flows are conditioned.

latent distributions.LatentDist

The latent distribution of the normalizing flows. Has its own sample and log_prob methods.

data_error_model Callable

The error model for the data variables. See the docstring of __init__ for more details.

condition_error_model Callable

The error model for the conditional variables. See the docstring of __init__ for more details.

info Any

Object containing any kind of info included with the ensemble. Often describes the data the flows are trained on.

Source code in pzflow/flowEnsemble.py
class FlowEnsemble:
    """An ensemble of normalizing flows.

    Attributes
    ----------
    data_columns : tuple
        List of DataFrame columns that the flows expect/produce.
    conditional_columns : tuple
        List of DataFrame columns on which the flows are conditioned.
    latent: distributions.LatentDist
        The latent distribution of the normalizing flows.
        Has its own sample and log_prob methods.
    data_error_model : Callable
        The error model for the data variables. See the docstring of
        __init__ for more details.
    condition_error_model : Callable
        The error model for the conditional variables. See the docstring
        of __init__ for more details.
    info : Any
        Object containing any kind of info included with the ensemble.
        Often describes the data the flows are trained on.
    """

    def __init__(
        self,
        data_columns: Sequence[str] = None,
        bijector: Tuple[InitFunction, Bijector_Info] = None,
        latent: distributions.LatentDist = None,
        conditional_columns: Sequence[str] = None,
        data_error_model: Callable = None,
        condition_error_model: Callable = None,
        autoscale_conditions: bool = True,
        N: int = 1,
        info: Any = None,
        file: str = None,
    ) -> None:
        """Instantiate an ensemble of normalizing flows.

        Note that while all of the init parameters are technically optional,
        you must provide either data_columns and bijector OR file.
        In addition, if a file is provided, all other parameters must be None.

        Parameters
        ----------
        data_columns : Sequence[str]; optional
            Tuple, list, or other container of column names.
            These are the columns the flows expect/produce in DataFrames.
        bijector : Bijector Call; optional
            A Bijector call that consists of the bijector InitFunction that
            initializes the bijector and the tuple of Bijector Info.
            Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
            If not provided, the bijector can be set later using
            flow.set_bijector, or by calling flow.train, in which case the
            default bijector will be used. The default bijector is
            ShiftBounds -> RollingSplineCoupling, where the range of shift
            bounds is learned from the training data, and the dimensions of
            RollingSplineCoupling is inferred. The default bijector assumes
            that the latent has support [-5, 5] for every dimension.
        latent : distributions.LatentDist; optional
            The latent distribution for the normalizing flow. Can be any of
            the distributions from pzflow.distributions. If not provided,
            a uniform distribution is used with input_dim = len(data_columns),
            and B=5.
        conditional_columns : Sequence[str]; optional
            Names of columns on which to condition the normalizing flows.
        data_error_model : Callable; optional
            A callable that defines the error model for data variables.
            data_error_model must take key, X, Xerr, nsamples as arguments:
                - key is a jax rng key, e.g. jax.random.PRNGKey(0)
                - X is 2D array of data variables, where the order of variables
                    matches the order of the columns in data_columns
                - Xerr is the corresponding 2D array of errors
                - nsamples is number of samples to draw from error distribution
            data_error_model must return an array of samples with the shape
            (X.shape[0], nsamples, X.shape[1]).
            If data_error_model is not provided, Gaussian error model assumed.
        condition_error_model : Callable; optional
            A callable that defines the error model for conditional variables.
            condition_error_model must take key, X, Xerr, nsamples, where:
                - key is a jax rng key, e.g. jax.random.PRNGKey(0)
                - X is 2D array of conditional variables, where the order of
                    variables matches order of columns in conditional_columns
                - Xerr is the corresponding 2D array of errors
                - nsamples is number of samples to draw from error distribution
            condition_error_model must return array of samples with shape
            (X.shape[0], nsamples, X.shape[1]).
            If condition_error_model is not provided, Gaussian error model
            assumed.
        autoscale_conditions : bool; default=True
            Sets whether or not conditions are automatically standard scaled
            when passed to a conditional flow. I recommend you leave as True.
        N : int; default=1
            The number of flows in the ensemble.
        info : Any; optional
            An object to attach to the info attribute.
        file : str; optional
            Path to file from which to load a pretrained flow ensemble.
            If a file is provided, all other parameters must be None.
        """

        # validate parameters
        if data_columns is None and file is None:
            raise ValueError("You must provide data_columns OR file.")
        if file is not None and any(
            (
                data_columns is not None,
                bijector is not None,
                conditional_columns is not None,
                latent is not None,
                data_error_model is not None,
                condition_error_model is not None,
                info is not None,
            )
        ):
            raise ValueError(
                "If providing a file, please do not provide any other parameters."
            )

        # if file is provided, load everything from the file
        if file is not None:
            # load the file
            with open(file, "rb") as handle:
                save_dict = pickle.load(handle)

            # make sure the saved file is for this class
            c = save_dict.pop("class")
            if c != self.__class__.__name__:
                raise TypeError(
                    f"This save file isn't a {self.__class__.__name__}. It is a {c}."
                )

            # load the ensemble from the dictionary
            self._ensemble = {
                name: Flow(_dictionary=flow_dict)
                for name, flow_dict in save_dict["ensemble"].items()
            }
            # load the metadata
            self.data_columns = save_dict["data_columns"]
            self.conditional_columns = save_dict["conditional_columns"]
            self.data_error_model = save_dict["data_error_model"]
            self.condition_error_model = save_dict["condition_error_model"]
            self.info = save_dict["info"]

            self._latent_info = save_dict["latent_info"]
            self.latent = getattr(distributions, self._latent_info[0])(
                *self._latent_info[1]
            )

        # otherwise create a new ensemble from the provided parameters
        else:
            # save the ensemble of flows
            self._ensemble = {
                f"Flow {i}": Flow(
                    data_columns=data_columns,
                    bijector=bijector,
                    conditional_columns=conditional_columns,
                    latent=latent,
                    data_error_model=data_error_model,
                    condition_error_model=condition_error_model,
                    autoscale_conditions=autoscale_conditions,
                    seed=i,
                    info=f"Flow {i}",
                )
                for i in range(N)
            }
            # save the metadata
            self.data_columns = data_columns
            self.conditional_columns = conditional_columns
            self.latent = self._ensemble["Flow 0"].latent
            self.data_error_model = data_error_model
            self.condition_error_model = condition_error_model
            self.info = info

    def log_prob(
        self,
        inputs: pd.DataFrame,
        err_samples: int = None,
        seed: int = None,
        returnEnsemble: bool = False,
    ) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        inputs : pd.DataFrame
            Input data for which log probability density is calculated.
            Every column in self.data_columns must be present.
            If self.conditional_columns is not None, those must be present
            as well. If other columns are present, they are ignored.
        err_samples : int; default=None
            Number of samples from the error distribution to average over for
            the log_prob calculation. If provided, Gaussian errors are assumed,
            and method will look for error columns in `inputs`. Error columns
            must end in `_err`. E.g. the error column for the variable `u` must
            be `u_err`. Zero error assumed for any missing error columns.
        seed : int; default=None
            Random seed for drawing the samples with Gaussian errors.
        returnEnsemble : bool; default=False
            If True, returns log_prob for each flow in the ensemble as an
            array of shape (inputs.shape[0], N flows in ensemble).
            If False, the prob is averaged over the flows in the ensemble,
            and the log of this average is returned as an array of shape
            (inputs.shape[0],)

        Returns
        -------
        jnp.ndarray
            For shape, see returnEnsemble description above.
        """

        # calculate log_prob for each flow in the ensemble
        ensemble = jnp.array(
            [
                flow.log_prob(inputs, err_samples, seed)
                for flow in self._ensemble.values()
            ]
        )

        # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
        ensemble = jnp.rollaxis(ensemble, axis=1)

        if returnEnsemble:
            # return the ensemble of log_probs
            return ensemble
        else:
            # return mean over ensemble
            # note we return log(mean prob) instead of just mean log_prob
            return jnp.log(jnp.exp(ensemble).mean(axis=1))

    def posterior(
        self,
        inputs: pd.DataFrame,
        column: str,
        grid: jnp.ndarray,
        marg_rules: dict = None,
        normalize: bool = True,
        err_samples: int = None,
        seed: int = None,
        batch_size: int = None,
        returnEnsemble: bool = False,
        nan_to_zero: bool = True,
    ) -> jnp.ndarray:
        """Calculates posterior distributions for the provided column.

        Calculates the conditional posterior distribution, assuming the
        data values in the other columns of the DataFrame.

        Parameters
        ----------
        inputs : pd.DataFrame
            Data on which the posterior distributions are conditioned.
            Must have columns matching self.data_columns, *except*
            for the column specified for the posterior (see below).
        column : str
            Name of the column for which the posterior distribution
            is calculated. Must be one of the columns in self.data_columns.
            However, whether or not this column is one of the columns in
            `inputs` is irrelevant.
        grid : jnp.ndarray
            Grid on which to calculate the posterior.
        marg_rules : dict; optional
            Dictionary with rules for marginalizing over missing variables.
            The dictionary must contain the key "flag", which gives the flag
            that indicates a missing value. E.g. if missing values are given
            the value 99, the dictionary should contain {"flag": 99}.
            The dictionary must also contain {"name": callable} for any
            variables that will need to be marginalized over, where name is
            the name of the variable, and callable is a callable that takes
            the row of variables and returns a grid over which to marginalize
            the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
            Note: the callable for a given name must *always* return an array
            of the same length, regardless of the input row.
        normalize : boolean; default=True
            Whether to normalize the posterior so that it integrates to 1.
        err_samples : int; default=None
            Number of samples from the error distribution to average over for
            the posterior calculation. If provided, Gaussian errors are assumed,
            and method will look for error columns in `inputs`. Error columns
            must end in `_err`. E.g. the error column for the variable `u` must
            be `u_err`. Zero error assumed for any missing error columns.
        seed : int; default=None
            Random seed for drawing the samples with Gaussian errors.
        batch_size : int; default=None
            Size of batches in which to calculate posteriors. If None, all
            posteriors are calculated simultaneously. Simultaneous calculation
            is faster, but memory intensive for large data sets.
        returnEnsemble : bool; default=False
            If True, returns posterior for each flow in the ensemble as an
            array of shape (inputs.shape[0], N flows in ensemble, grid.size).
            If False, the posterior is averaged over the flows in the ensemble,
            and returned as an array of shape (inputs.shape[0], grid.size)
        nan_to_zero : bool; default=True
            Whether to convert NaNs to zero probability in the final pdfs.

        Returns
        -------
        jnp.ndarray
            For shape, see returnEnsemble description above.
        """

        # calculate posterior for each flow in the ensemble
        ensemble = jnp.array(
            [
                flow.posterior(
                    inputs=inputs,
                    column=column,
                    grid=grid,
                    marg_rules=marg_rules,
                    err_samples=err_samples,
                    seed=seed,
                    batch_size=batch_size,
                    normalize=False,
                    nan_to_zero=nan_to_zero,
                )
                for flow in self._ensemble.values()
            ]
        )

        # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
        ensemble = jnp.rollaxis(ensemble, axis=1)

        if returnEnsemble:
            # return the ensemble of posteriors
            if normalize:
                ensemble = ensemble.reshape(-1, grid.size)
                ensemble = ensemble / jnp.trapz(y=ensemble, x=grid).reshape(
                    -1, 1
                )
                ensemble = ensemble.reshape(inputs.shape[0], -1, grid.size)
            return ensemble
        else:
            # return mean over ensemble
            pdfs = ensemble.mean(axis=1)
            if normalize:
                pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
            return pdfs

    def sample(
        self,
        nsamples: int = 1,
        conditions: pd.DataFrame = None,
        save_conditions: bool = True,
        seed: int = None,
        returnEnsemble: bool = False,
    ) -> pd.DataFrame:
        """Returns samples from the ensemble.

        Parameters
        ----------
        nsamples : int; default=1
            The number of samples to be returned, either overall or per flow
            in the ensemble (see returnEnsemble below).
        conditions : pd.DataFrame; optional
            If this is a conditional flow, you must pass conditions for
            each sample. nsamples will be drawn for each row in conditions.
        save_conditions : bool; default=True
            If true, conditions will be saved in the DataFrame of samples
            that is returned.
        seed : int; optional
            Sets the random seed for the samples.
        returnEnsemble : bool; default=False
            If True, nsamples is drawn from each flow in the ensemble.
            If False, nsamples are drawn uniformly from the flows in the ensemble.

        Returns
        -------
        pd.DataFrame
            Pandas DataFrame of samples.
        """

        if returnEnsemble:
            # return nsamples for each flow in the ensemble
            return pd.concat(
                [
                    flow.sample(nsamples, conditions, save_conditions, seed)
                    for flow in self._ensemble.values()
                ],
                keys=self._ensemble.keys(),
            )
        else:
            # if this isn't a conditional flow, sampling is straightforward
            if conditions is None:
                # return nsamples drawn uniformly from the flows in the ensemble
                N = int(jnp.ceil(nsamples / len(self._ensemble)))
                samples = pd.concat(
                    [
                        flow.sample(N, conditions, save_conditions, seed)
                        for flow in self._ensemble.values()
                    ]
                )
                return samples.sample(nsamples, random_state=seed).reset_index(
                    drop=True
                )
            # if this is a conditional flow, it's a little more complicated...
            else:
                # if nsamples > 1, we duplicate the rows of the conditions
                if nsamples > 1:
                    conditions = pd.concat([conditions] * nsamples)

                # now the main sampling algorithm
                seed = np.random.randint(1e18) if seed is None else seed
                # if we are drawing more samples than the number of flows in
                # the ensemble, then we will shuffle the conditions and randomly
                # assign them to one of the constituent flows
                if conditions.shape[0] > len(self._ensemble):
                    # shuffle the conditions
                    conditions_shuffled = conditions.sample(
                        frac=1.0, random_state=int(seed / 1e9)
                    )
                    # split conditions into ~equal sized chunks
                    chunks = np.array_split(
                        conditions_shuffled, len(self._ensemble)
                    )
                    # shuffle the chunks
                    chunks = [
                        chunks[i]
                        for i in random.permutation(
                            random.PRNGKey(seed), jnp.arange(len(chunks))
                        )
                    ]
                    # sample from each flow, and return all the samples
                    return pd.concat(
                        [
                            flow.sample(
                                1, chunk, save_conditions, seed
                            ).set_index(chunk.index)
                            for flow, chunk in zip(
                                self._ensemble.values(), chunks
                            )
                        ]
                    ).sort_index()
                # however, if there are more flows in the ensemble than samples
                # being drawn, then we will randomly select flows for each condition
                else:
                    rng = np.random.default_rng(seed)
                    # randomly select a flow to sample from for each condition
                    flows = rng.choice(
                        list(self._ensemble.values()),
                        size=conditions.shape[0],
                        replace=True,
                    )
                    # sample from each flow and return all the samples together
                    seeds = rng.integers(1e18, size=conditions.shape[0])
                    return pd.concat(
                        [
                            flow.sample(
                                1,
                                conditions[i : i + 1],
                                save_conditions,
                                new_seed,
                            )
                            for i, (flow, new_seed) in enumerate(
                                zip(flows, seeds)
                            )
                        ],
                    ).set_index(conditions.index)

    def save(self, file: str) -> None:
        """Saves the ensemble to a file.

        Pickles the ensemble and saves it to a file that can be passed as
        the `file` argument during flow instantiation.

        WARNING: Currently, this method only works for bijectors that are
        implemented in the `bijectors` module. If you want to save a flow
        with a custom bijector, you either need to add the bijector to that
        module, or handle the saving and loading on your end.

        Parameters
        ----------
        file : str
            Path to where the ensemble will be saved.
            Extension `.pkl` will be appended if not already present.
        """
        save_dict = {
            "data_columns": self.data_columns,
            "conditional_columns": self.conditional_columns,
            "latent_info": self.latent.info,
            "data_error_model": self.data_error_model,
            "condition_error_model": self.condition_error_model,
            "info": self.info,
            "class": self.__class__.__name__,
            "ensemble": {
                name: flow._save_dict()
                for name, flow in self._ensemble.items()
            },
        }

        with open(file, "wb") as handle:
            pickle.dump(save_dict, handle, recurse=True)

    def train(
        self,
        inputs: pd.DataFrame,
        val_set: pd.DataFrame = None,
        epochs: int = 50,
        batch_size: int = 1024,
        optimizer: Callable = None,
        loss_fn: Callable = None,
        convolve_errs: bool = False,
        patience: int = None,
        best_params: bool = True,
        seed: int = 0,
        verbose: bool = False,
        progress_bar: bool = False,
    ) -> dict:
        """Trains the normalizing flows on the provided inputs.

        Parameters
        ----------
        inputs : pd.DataFrame
            Data on which to train the normalizing flows.
            Must have columns matching self.data_columns.
        val_set : pd.DataFrame; default=None
            Validation set, of same format as inputs. If provided,
            validation loss will be calculated at the end of each epoch.
        epochs : int; default=50
            Number of epochs to train.
        batch_size : int; default=1024
            Batch size for training.
        optimizer : optax optimizer
            An optimizer from Optax. default = optax.adam(learning_rate=1e-3)
            see https://optax.readthedocs.io/en/latest/index.html for more.
        loss_fn : Callable; optional
            A function to calculate the loss: loss = loss_fn(params, x).
            If not provided, will be -mean(log_prob).
        convolve_errs : bool; default=False
            Whether to draw new data from the error distributions during
            each epoch of training. Method will look for error columns in
            `inputs`. Error columns must end in `_err`. E.g. the error column
            for the variable `u` must be `u_err`. Zero error assumed for
            any missing error columns. The error distribution is set during
            ensemble instantiation.
        patience : int; optional
            Factor that controls early stopping. Training will stop if the
            loss doesn't decrease for this number of epochs.
        best_params : bool; default=True
            Whether to use the params from the epoch with the lowest loss.
            Note if a validation set is provided, the epoch with the lowest
            validation loss is chosen. If False, the params from the final
            epoch are saved.
        seed : int; default=0
            A random seed to control the batching and the (optional)
            error sampling.
        verbose : bool; default=False
            If true, print the training loss every 5% of epochs.
        progress_bar : bool; default=False
            If true, display a tqdm progress bar during training.

        Returns
        -------
        dict
            Dictionary of training losses from every epoch for each flow
            in the ensemble.
        """

        # generate random seeds for each flow
        rng = np.random.default_rng(seed)
        seeds = rng.integers(1e9, size=len(self._ensemble))

        loss_dict = dict()

        for i, (name, flow) in enumerate(self._ensemble.items()):
            if verbose:
                print(name)

            loss_dict[name] = flow.train(
                inputs=inputs,
                val_set=val_set,
                epochs=epochs,
                batch_size=batch_size,
                optimizer=optimizer,
                loss_fn=loss_fn,
                convolve_errs=convolve_errs,
                patience=patience,
                best_params=best_params,
                seed=seeds[i],
                verbose=verbose,
                progress_bar=progress_bar,
            )

        return loss_dict
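
A hypothetical end-to-end sketch of an ensemble; train_df and data are placeholder DataFrames, and the column names are illustrative.

    import jax.numpy as jnp
    from pzflow import FlowEnsemble

    ens = FlowEnsemble(("redshift", "u", "g"), N=4)

    loss_dict = ens.train(train_df, epochs=100)  # one loss list per flow in the ensemble

    samples = ens.sample(nsamples=1000, seed=0)  # samples pooled across the flows
    pdfs = ens.posterior(data, column="redshift", grid=jnp.linspace(0, 3, 100))
    # pdfs averaged over the ensemble, shape (data.shape[0], 100)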

__init__(data_columns=None, bijector=None, latent=None, conditional_columns=None, data_error_model=None, condition_error_model=None, autoscale_conditions=True, N=1, info=None, file=None)

Instantiate an ensemble of normalizing flows.

Note that while all of the init parameters are technically optional, you must provide either data_columns and bijector OR file. In addition, if a file is provided, all other parameters must be None.

Parameters:

Name Type Description Default
data_columns Sequence[str]

Tuple, list, or other container of column names. These are the columns the flows expect/produce in DataFrames.

None
bijector Bijector Call

A Bijector call that consists of the bijector InitFunction that initializes the bijector and the tuple of Bijector Info. Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc. If not provided, the bijector can be set later using flow.set_bijector, or by calling flow.train, in which case the default bijector will be used. The default bijector is ShiftBounds -> RollingSplineCoupling, where the range of shift bounds is learned from the training data, and the dimensions of RollingSplineCoupling is inferred. The default bijector assumes that the latent has support [-5, 5] for every dimension.

None
latent distributions.LatentDist

The latent distribution for the normalizing flow. Can be any of the distributions from pzflow.distributions. If not provided, a uniform distribution is used with input_dim = len(data_columns), and B=5.

None
conditional_columns Sequence[str]

Names of columns on which to condition the normalizing flows.

None
data_error_model Callable

A callable that defines the error model for data variables. data_error_model must take key, X, Xerr, nsamples as arguments: - key is a jax rng key, e.g. jax.random.PRNGKey(0) - X is 2D array of data variables, where the order of variables matches the order of the columns in data_columns - Xerr is the corresponding 2D array of errors - nsamples is number of samples to draw from error distribution data_error_model must return an array of samples with the shape (X.shape[0], nsamples, X.shape[1]). If data_error_model is not provided, Gaussian error model assumed.

None
condition_error_model Callable

A callable that defines the error model for conditional variables. condition_error_model must take key, X, Xerr, nsamples, where: - key is a jax rng key, e.g. jax.random.PRNGKey(0) - X is 2D array of conditional variables, where the order of variables matches order of columns in conditional_columns - Xerr is the corresponding 2D array of errors - nsamples is number of samples to draw from error distribution condition_error_model must return array of samples with shape (X.shape[0], nsamples, X.shape[1]). If condition_error_model is not provided, Gaussian error model assumed.

None
autoscale_conditions bool

Sets whether or not conditions are automatically standard scaled when passed to a conditional flow. I recommend you leave as True.

True
N int

The number of flows in the ensemble.

1
info Any

An object to attach to the info attribute.

None
file str

Path to file from which to load a pretrained flow ensemble. If a file is provided, all other parameters must be None.

None
Source code in pzflow/flowEnsemble.py
def __init__(
    self,
    data_columns: Sequence[str] = None,
    bijector: Tuple[InitFunction, Bijector_Info] = None,
    latent: distributions.LatentDist = None,
    conditional_columns: Sequence[str] = None,
    data_error_model: Callable = None,
    condition_error_model: Callable = None,
    autoscale_conditions: bool = True,
    N: int = 1,
    info: Any = None,
    file: str = None,
) -> None:
    """Instantiate an ensemble of normalizing flows.

    Note that while all of the init parameters are technically optional,
    you must provide either data_columns and bijector OR file.
    In addition, if a file is provided, all other parameters must be None.

    Parameters
    ----------
    data_columns : Sequence[str]; optional
        Tuple, list, or other container of column names.
        These are the columns the flows expect/produce in DataFrames.
    bijector : Bijector Call; optional
        A Bijector call that consists of the bijector InitFunction that
        initializes the bijector and the tuple of Bijector Info.
        Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
        If not provided, the bijector can be set later using
        flow.set_bijector, or by calling flow.train, in which case the
        default bijector will be used. The default bijector is
        ShiftBounds -> RollingSplineCoupling, where the range of shift
        bounds is learned from the training data, and the dimensions of
        RollingSplineCoupling is inferred. The default bijector assumes
        that the latent has support [-5, 5] for every dimension.
    latent : distributions.LatentDist; optional
        The latent distribution for the normalizing flow. Can be any of
        the distributions from pzflow.distributions. If not provided,
        a uniform distribution is used with input_dim = len(data_columns),
        and B=5.
    conditional_columns : Sequence[str]; optional
        Names of columns on which to condition the normalizing flows.
    data_error_model : Callable; optional
        A callable that defines the error model for data variables.
        data_error_model must take key, X, Xerr, nsamples as arguments:
            - key is a jax rng key, e.g. jax.random.PRNGKey(0)
            - X is 2D array of data variables, where the order of variables
                matches the order of the columns in data_columns
            - Xerr is the corresponding 2D array of errors
            - nsamples is number of samples to draw from error distribution
        data_error_model must return an array of samples with the shape
        (X.shape[0], nsamples, X.shape[1]).
        If data_error_model is not provided, Gaussian error model assumed.
    condition_error_model : Callable; optional
        A callable that defines the error model for conditional variables.
        condition_error_model must take key, X, Xerr, nsamples, where:
            - key is a jax rng key, e.g. jax.random.PRNGKey(0)
            - X is 2D array of conditional variables, where the order of
                variables matches order of columns in conditional_columns
            - Xerr is the corresponding 2D array of errors
            - nsamples is number of samples to draw from error distribution
        condition_error_model must return array of samples with shape
        (X.shape[0], nsamples, X.shape[1]).
        If condition_error_model is not provided, Gaussian error model
        assumed.
    autoscale_conditions : bool; default=True
        Sets whether or not conditions are automatically standard scaled
        when passed to a conditional flow. I recommend you leave as True.
    N : int; default=1
        The number of flows in the ensemble.
    info : Any; optional
        An object to attach to the info attribute.
    file : str; optional
        Path to file from which to load a pretrained flow ensemble.
        If a file is provided, all other parameters must be None.
    """

    # validate parameters
    if data_columns is None and file is None:
        raise ValueError("You must provide data_columns OR file.")
    if file is not None and any(
        (
            data_columns is not None,
            bijector is not None,
            conditional_columns is not None,
            latent is not None,
            data_error_model is not None,
            condition_error_model is not None,
            info is not None,
        )
    ):
        raise ValueError(
            "If providing a file, please do not provide any other parameters."
        )

    # if file is provided, load everything from the file
    if file is not None:
        # load the file
        with open(file, "rb") as handle:
            save_dict = pickle.load(handle)

        # make sure the saved file is for this class
        c = save_dict.pop("class")
        if c != self.__class__.__name__:
            raise TypeError(
                f"This save file isn't a {self.__class__.__name__}. It is a {c}."
            )

        # load the ensemble from the dictionary
        self._ensemble = {
            name: Flow(_dictionary=flow_dict)
            for name, flow_dict in save_dict["ensemble"].items()
        }
        # load the metadata
        self.data_columns = save_dict["data_columns"]
        self.conditional_columns = save_dict["conditional_columns"]
        self.data_error_model = save_dict["data_error_model"]
        self.condition_error_model = save_dict["condition_error_model"]
        self.info = save_dict["info"]

        self._latent_info = save_dict["latent_info"]
        self.latent = getattr(distributions, self._latent_info[0])(
            *self._latent_info[1]
        )

    # otherwise create a new ensemble from the provided parameters
    else:
        # save the ensemble of flows
        self._ensemble = {
            f"Flow {i}": Flow(
                data_columns=data_columns,
                bijector=bijector,
                conditional_columns=conditional_columns,
                latent=latent,
                data_error_model=data_error_model,
                condition_error_model=condition_error_model,
                autoscale_conditions=autoscale_conditions,
                seed=i,
                info=f"Flow {i}",
            )
            for i in range(N)
        }
        # save the metadata
        self.data_columns = data_columns
        self.conditional_columns = conditional_columns
        self.latent = self._ensemble["Flow 0"].latent
        self.data_error_model = data_error_model
        self.condition_error_model = condition_error_model
        self.info = info
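
For orientation, here is a minimal, hypothetical sketch of building a small ensemble. The column names, the bijector settings, and N=3 are illustrative choices, not defaults of the API.

from pzflow import FlowEnsemble
from pzflow.bijectors import Chain, RollingSplineCoupling, ShiftBounds

# two hypothetical data columns and an explicit bijector
# (omitting `bijector` would defer to the default built at train time)
ensemble = FlowEnsemble(
    data_columns=("x", "y"),
    bijector=Chain(ShiftBounds(-4.0, 4.0), RollingSplineCoupling(nlayers=2)),
    N=3,
)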

log_prob(inputs, err_samples=None, seed=None, returnEnsemble=False)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Input data for which log probability density is calculated. Every column in self.data_columns must be present. If self.conditional_columns is not None, those must be present as well. If other columns are present, they are ignored.

required
err_samples int

Number of samples from the error distribution to average over for the log_prob calculation. If provided, Gaussian errors are assumed, and method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns.

None
seed int

Random seed for drawing the samples with Gaussian errors.

None
returnEnsemble bool

If True, returns log_prob for each flow in the ensemble as an array of shape (inputs.shape[0], N flows in ensemble). If False, the prob is averaged over the flows in the ensemble, and the log of this average is returned as an array of shape (inputs.shape[0],)

False

Returns:

Type Description
jnp.ndarray

For shape, see returnEnsemble description above.

Source code in pzflow/flowEnsemble.py
def log_prob(
    self,
    inputs: pd.DataFrame,
    err_samples: int = None,
    seed: int = None,
    returnEnsemble: bool = False,
) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Input data for which log probability density is calculated.
        Every column in self.data_columns must be present.
        If self.conditional_columns is not None, those must be present
        as well. If other columns are present, they are ignored.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the log_prob calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    returnEnsemble : bool; default=False
        If True, returns log_prob for each flow in the ensemble as an
        array of shape (inputs.shape[0], N flows in ensemble).
        If False, the prob is averaged over the flows in the ensemble,
        and the log of this average is returned as an array of shape
        (inputs.shape[0],)

    Returns
    -------
    jnp.ndarray
        For shape, see returnEnsemble description above.
    """

    # calculate log_prob for each flow in the ensemble
    ensemble = jnp.array(
        [
            flow.log_prob(inputs, err_samples, seed)
            for flow in self._ensemble.values()
        ]
    )

    # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
    ensemble = jnp.rollaxis(ensemble, axis=1)

    if returnEnsemble:
        # return the ensemble of log_probs
        return ensemble
    else:
        # return mean over ensemble
        # note we return log(mean prob) instead of just mean log_prob
        return jnp.log(jnp.exp(ensemble).mean(axis=1))
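
As an illustration, continuing the hypothetical ensemble from the __init__ sketch above, log_prob can be evaluated on any DataFrame whose columns include the ensemble's data_columns:

import numpy as np
import pandas as pd

# hypothetical data with the same columns as the ensemble above
data = pd.DataFrame(
    np.random.default_rng(0).normal(size=(100, 2)), columns=("x", "y")
)

logp = ensemble.log_prob(data)                             # shape (100,), averaged over flows
logp_each = ensemble.log_prob(data, returnEnsemble=True)   # shape (100, 3), one column per flow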

posterior(inputs, column, grid, marg_rules=None, normalize=True, err_samples=None, seed=None, batch_size=None, returnEnsemble=False, nan_to_zero=True)

Calculates posterior distributions for the provided column.

Calculates the conditional posterior distribution, assuming the data values in the other columns of the DataFrame.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Data on which the posterior distributions are conditioned. Must have columns matching self.data_columns, except for the column specified for the posterior (see below).

required
column str

Name of the column for which the posterior distribution is calculated. Must be one of the columns in self.data_columns. However, whether or not this column is one of the columns in inputs is irrelevant.

required
grid jnp.ndarray

Grid on which to calculate the posterior.

required
marg_rules dict

Dictionary with rules for marginalizing over missing variables. The dictionary must contain the key "flag", which gives the flag that indicates a missing value. E.g. if missing values are given the value 99, the dictionary should contain {"flag": 99}. The dictionary must also contain {"name": callable} for any variables that will need to be marginalized over, where name is the name of the variable, and callable is a callable that takes the row of variables and returns a grid over which to marginalize the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}. Note: the callable for a given name must always return an array of the same length, regardless of the input row.

None
normalize boolean

Whether to normalize the posterior so that it integrates to 1.

True
err_samples int

Number of samples from the error distribution to average over for the posterior calculation. If provided, Gaussian errors are assumed, and method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns.

None
seed int

Random seed for drawing the samples with Gaussian errors.

None
batch_size int

Size of batches in which to calculate posteriors. If None, all posteriors are calculated simultaneously. Simultaneous calculation is faster, but memory intensive for large data sets.

None
returnEnsemble bool

If True, returns posterior for each flow in the ensemble as an array of shape (inputs.shape[0], N flows in ensemble, grid.size). If False, the posterior is averaged over the flows in the ensemble, and returned as an array of shape (inputs.shape[0], grid.size)

False
nan_to_zero bool

Whether to convert NaN's to zero probability in the final pdfs.

True

Returns:

Type Description
jnp.ndarray

For shape, see returnEnsemble description above.

Source code in pzflow/flowEnsemble.py
def posterior(
    self,
    inputs: pd.DataFrame,
    column: str,
    grid: jnp.ndarray,
    marg_rules: dict = None,
    normalize: bool = True,
    err_samples: int = None,
    seed: int = None,
    batch_size: int = None,
    returnEnsemble: bool = False,
    nan_to_zero: bool = True,
) -> jnp.ndarray:
    """Calculates posterior distributions for the provided column.

    Calculates the conditional posterior distribution, assuming the
    data values in the other columns of the DataFrame.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which the posterior distributions are conditioned.
        Must have columns matching self.data_columns, *except*
        for the column specified for the posterior (see below).
    column : str
        Name of the column for which the posterior distribution
        is calculated. Must be one of the columns in self.data_columns.
        However, whether or not this column is one of the columns in
        `inputs` is irrelevant.
    grid : jnp.ndarray
        Grid on which to calculate the posterior.
    marg_rules : dict; optional
        Dictionary with rules for marginalizing over missing variables.
        The dictionary must contain the key "flag", which gives the flag
        that indicates a missing value. E.g. if missing values are given
        the value 99, the dictionary should contain {"flag": 99}.
        The dictionary must also contain {"name": callable} for any
        variables that will need to be marginalized over, where name is
        the name of the variable, and callable is a callable that takes
        the row of variables and returns a grid over which to marginalize
        the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
        Note: the callable for a given name must *always* return an array
        of the same length, regardless of the input row.
    normalize : boolean; default=True
        Whether to normalize the posterior so that it integrates to 1.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the posterior calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    batch_size : int; default=None
        Size of batches in which to calculate posteriors. If None, all
        posteriors are calculated simultaneously. Simultaneous calculation
        is faster, but memory intensive for large data sets.
    returnEnsemble : bool; default=False
        If True, returns posterior for each flow in the ensemble as an
        array of shape (inputs.shape[0], N flows in ensemble, grid.size).
        If False, the posterior is averaged over the flows in the ensemble,
        and returned as an array of shape (inputs.shape[0], grid.size)
    nan_to_zero : bool; default=True
        Whether to convert NaN's to zero probability in the final pdfs.

    Returns
    -------
    jnp.ndarray
        For shape, see returnEnsemble description above.
    """

    # calculate posterior for each flow in the ensemble
    ensemble = jnp.array(
        [
            flow.posterior(
                inputs=inputs,
                column=column,
                grid=grid,
                marg_rules=marg_rules,
                err_samples=err_samples,
                seed=seed,
                batch_size=batch_size,
                normalize=False,
                nan_to_zero=nan_to_zero,
            )
            for flow in self._ensemble.values()
        ]
    )

    # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
    ensemble = jnp.rollaxis(ensemble, axis=1)

    if returnEnsemble:
        # return the ensemble of posteriors
        if normalize:
            ensemble = ensemble.reshape(-1, grid.size)
            ensemble = ensemble / jnp.trapz(y=ensemble, x=grid).reshape(
                -1, 1
            )
            ensemble = ensemble.reshape(inputs.shape[0], -1, grid.size)
        return ensemble
    else:
        # return mean over ensemble
        pdfs = ensemble.mean(axis=1)
        if normalize:
            pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
        return pdfs
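
A short sketch, reusing the hypothetical ensemble and data from above: the posterior for one column is evaluated on a user-supplied grid, conditioned on the remaining columns.

import jax.numpy as jnp

grid = jnp.linspace(-3, 3, 64)

# averaged over the ensemble: shape (100, 64)
pdfs = ensemble.posterior(data, column="x", grid=grid)

# one pdf per flow: shape (100, 3, 64)
pdfs_each = ensemble.posterior(data, column="x", grid=grid, returnEnsemble=True)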

sample(nsamples=1, conditions=None, save_conditions=True, seed=None, returnEnsemble=False)

Returns samples from the ensemble.

Parameters:

Name Type Description Default
nsamples int

The number of samples to be returned, either overall or per flow in the ensemble (see returnEnsemble below).

1
conditions pd.DataFrame

If this is a conditional flow, you must pass conditions for each sample. nsamples will be drawn for each row in conditions.

None
save_conditions bool

If true, conditions will be saved in the DataFrame of samples that is returned.

True
seed int

Sets the random seed for the samples.

None
returnEnsemble bool

If True, nsamples is drawn from each flow in the ensemble. If False, nsamples are drawn uniformly from the flows in the ensemble.

False

Returns:

Type Description
pd.DataFrame

Pandas DataFrame of samples.

Source code in pzflow/flowEnsemble.py
def sample(
    self,
    nsamples: int = 1,
    conditions: pd.DataFrame = None,
    save_conditions: bool = True,
    seed: int = None,
    returnEnsemble: bool = False,
) -> pd.DataFrame:
    """Returns samples from the ensemble.

    Parameters
    ----------
    nsamples : int; default=1
        The number of samples to be returned, either overall or per flow
        in the ensemble (see returnEnsemble below).
    conditions : pd.DataFrame; optional
        If this is a conditional flow, you must pass conditions for
        each sample. nsamples will be drawn for each row in conditions.
    save_conditions : bool; default=True
        If true, conditions will be saved in the DataFrame of samples
        that is returned.
    seed : int; optional
        Sets the random seed for the samples.
    returnEnsemble : bool; default=False
        If True, nsamples is drawn from each flow in the ensemble.
        If False, nsamples are drawn uniformly from the flows in the ensemble.

    Returns
    -------
    pd.DataFrame
        Pandas DataFrame of samples.
    """

    if returnEnsemble:
        # return nsamples for each flow in the ensemble
        return pd.concat(
            [
                flow.sample(nsamples, conditions, save_conditions, seed)
                for flow in self._ensemble.values()
            ],
            keys=self._ensemble.keys(),
        )
    else:
        # if this isn't a conditional flow, sampling is straightforward
        if conditions is None:
            # return nsamples drawn uniformly from the flows in the ensemble
            N = int(jnp.ceil(nsamples / len(self._ensemble)))
            samples = pd.concat(
                [
                    flow.sample(N, conditions, save_conditions, seed)
                    for flow in self._ensemble.values()
                ]
            )
            return samples.sample(nsamples, random_state=seed).reset_index(
                drop=True
            )
        # if this is a conditional flow, it's a little more complicated...
        else:
            # if nsamples > 1, we duplicate the rows of the conditions
            if nsamples > 1:
                conditions = pd.concat([conditions] * nsamples)

            # now the main sampling algorithm
            seed = np.random.randint(1e18) if seed is None else seed
            # if we are drawing more samples than the number of flows in
            # the ensemble, then we will shuffle the conditions and randomly
            # assign them to one of the constituent flows
            if conditions.shape[0] > len(self._ensemble):
                # shuffle the conditions
                conditions_shuffled = conditions.sample(
                    frac=1.0, random_state=int(seed / 1e9)
                )
                # split conditions into ~equal sized chunks
                chunks = np.array_split(
                    conditions_shuffled, len(self._ensemble)
                )
                # shuffle the chunks
                chunks = [
                    chunks[i]
                    for i in random.permutation(
                        random.PRNGKey(seed), jnp.arange(len(chunks))
                    )
                ]
                # sample from each flow, and return all the samples
                return pd.concat(
                    [
                        flow.sample(
                            1, chunk, save_conditions, seed
                        ).set_index(chunk.index)
                        for flow, chunk in zip(
                            self._ensemble.values(), chunks
                        )
                    ]
                ).sort_index()
            # however, if there are more flows in the ensemble than samples
            # being drawn, then we will randomly select flows for each condition
            else:
                rng = np.random.default_rng(seed)
                # randomly select a flow to sample from for each condition
                flows = rng.choice(
                    list(self._ensemble.values()),
                    size=conditions.shape[0],
                    replace=True,
                )
                # sample from each flow and return all the samples together
                seeds = rng.integers(1e18, size=conditions.shape[0])
                return pd.concat(
                    [
                        flow.sample(
                            1,
                            conditions[i : i + 1],
                            save_conditions,
                            new_seed,
                        )
                        for i, (flow, new_seed) in enumerate(
                            zip(flows, seeds)
                        )
                    ],
                ).set_index(conditions.index)
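
Continuing the same hypothetical ensemble, a quick sampling sketch:

# 1000 samples pooled across the flows in the ensemble
samples = ensemble.sample(1000, seed=0)

# 1000 samples from *each* flow, indexed by flow name ("Flow 0", "Flow 1", ...)
samples_each = ensemble.sample(1000, seed=0, returnEnsemble=True)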

save(file)

Saves the ensemble to a file.

Pickles the ensemble and saves it to a file that can be passed as the file argument during flow instantiation.

WARNING: Currently, this method only works for bijectors that are implemented in the bijectors module. If you want to save a flow with a custom bijector, you either need to add the bijector to that module, or handle the saving and loading on your end.

Parameters:

Name Type Description Default
file str

Path to where the ensemble will be saved. Extension .pkl will be appended if not already present.

required
Source code in pzflow/flowEnsemble.py
def save(self, file: str) -> None:
    """Saves the ensemble to a file.

    Pickles the ensemble and saves it to a file that can be passed as
    the `file` argument during flow instantiation.

    WARNING: Currently, this method only works for bijectors that are
    implemented in the `bijectors` module. If you want to save a flow
    with a custom bijector, you either need to add the bijector to that
    module, or handle the saving and loading on your end.

    Parameters
    ----------
    file : str
        Path to where the ensemble will be saved.
        Extension `.pkl` will be appended if not already present.
    """
    save_dict = {
        "data_columns": self.data_columns,
        "conditional_columns": self.conditional_columns,
        "latent_info": self.latent.info,
        "data_error_model": self.data_error_model,
        "condition_error_model": self.condition_error_model,
        "info": self.info,
        "class": self.__class__.__name__,
        "ensemble": {
            name: flow._save_dict()
            for name, flow in self._ensemble.items()
        },
    }

    with open(file, "wb") as handle:
        pickle.dump(save_dict, handle, recurse=True)
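
A small save-and-reload sketch (the file name is arbitrary); note that when loading, all other constructor arguments must be left as None:

from pzflow import FlowEnsemble

ensemble.save("ensemble.pkl")
restored = FlowEnsemble(file="ensemble.pkl")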

train(inputs, val_set=None, epochs=50, batch_size=1024, optimizer=None, loss_fn=None, convolve_errs=False, patience=None, best_params=True, seed=0, verbose=False, progress_bar=False)

Trains the normalizing flows on the provided inputs.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Data on which to train the normalizing flows. Must have columns matching self.data_columns.

required
val_set pd.DataFrame

Validation set, of same format as inputs. If provided, validation loss will be calculated at the end of each epoch.

None
epochs int

Number of epochs to train.

50
batch_size int

Batch size for training.

1024
optimizer optax optimizer

An optimizer from Optax. default = optax.adam(learning_rate=1e-3) see https://optax.readthedocs.io/en/latest/index.html for more.

None
loss_fn Callable

A function to calculate the loss: loss = loss_fn(params, x). If not provided, will be -mean(log_prob).

None
convolve_errs bool

Whether to draw new data from the error distributions during each epoch of training. Method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns. The error distribution is set during ensemble instantiation.

False
patience int

Factor that controls early stopping. Training will stop if the loss doesn't decrease for this number of epochs.

None
best_params bool

Whether to use the params from the epoch with the lowest loss. Note if a validation set is provided, the epoch with the lowest validation loss is chosen. If False, the params from the final epoch are saved.

True
seed int

A random seed to control the batching and the (optional) error sampling.

0
verbose bool

If true, print the training loss every 5% of epochs.

False
progress_bar bool

If true, display a tqdm progress bar during training.

False

Returns:

Type Description
dict

Dictionary of training losses from every epoch for each flow in the ensemble.

Source code in pzflow/flowEnsemble.py
def train(
    self,
    inputs: pd.DataFrame,
    val_set: pd.DataFrame = None,
    epochs: int = 50,
    batch_size: int = 1024,
    optimizer: Callable = None,
    loss_fn: Callable = None,
    convolve_errs: bool = False,
    patience: int = None,
    best_params: bool = True,
    seed: int = 0,
    verbose: bool = False,
    progress_bar: bool = False,
) -> dict:
    """Trains the normalizing flows on the provided inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which to train the normalizing flows.
        Must have columns matching self.data_columns.
    val_set : pd.DataFrame; default=None
        Validation set, of same format as inputs. If provided,
        validation loss will be calculated at the end of each epoch.
    epochs : int; default=50
        Number of epochs to train.
    batch_size : int; default=1024
        Batch size for training.
    optimizer : optax optimizer
        An optimizer from Optax. default = optax.adam(learning_rate=1e-3)
        see https://optax.readthedocs.io/en/latest/index.html for more.
    loss_fn : Callable; optional
        A function to calculate the loss: loss = loss_fn(params, x).
        If not provided, will be -mean(log_prob).
    convolve_errs : bool; default=False
        Whether to draw new data from the error distributions during
        each epoch of training. Method will look for error columns in
        `inputs`. Error columns must end in `_err`. E.g. the error column
        for the variable `u` must be `u_err`. Zero error assumed for
        any missing error columns. The error distribution is set during
        ensemble instantiation.
    patience : int; optional
        Factor that controls early stopping. Training will stop if the
        loss doesn't decrease for this number of epochs.
    best_params : bool; default=True
        Whether to use the params from the epoch with the lowest loss.
        Note if a validation set is provided, the epoch with the lowest
        validation loss is chosen. If False, the params from the final
        epoch are saved.
    seed : int; default=0
        A random seed to control the batching and the (optional)
        error sampling.
    verbose : bool; default=False
        If true, print the training loss every 5% of epochs.
    progress_bar : bool; default=False
        If true, display a tqdm progress bar during training.

    Returns
    -------
    dict
        Dictionary of training losses from every epoch for each flow
        in the ensemble.
    """

    # generate random seeds for each flow
    rng = np.random.default_rng(seed)
    seeds = rng.integers(1e9, size=len(self._ensemble))

    loss_dict = dict()

    for i, (name, flow) in enumerate(self._ensemble.items()):
        if verbose:
            print(name)

        loss_dict[name] = flow.train(
            inputs=inputs,
            val_set=val_set,
            epochs=epochs,
            batch_size=batch_size,
            optimizer=optimizer,
            loss_fn=loss_fn,
            convolve_errs=convolve_errs,
            patience=patience,
            best_params=best_params,
            seed=seeds[i],
            verbose=verbose,
            progress_bar=progress_bar,
        )

    return loss_dict
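
Finally, an illustrative training sketch using the hypothetical ensemble and data from the earlier examples; the optimizer and epoch count are arbitrary choices:

import optax

losses = ensemble.train(
    data,
    epochs=100,
    optimizer=optax.adam(learning_rate=1e-3),
    verbose=True,
)
# losses is a dict of per-epoch training losses, keyed by flow name:
# {"Flow 0": [...], "Flow 1": [...], "Flow 2": [...]}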

utils

Define utility functions for use in other modules.

DenseReluNetwork(out_dim, hidden_layers, hidden_dim)

Create a dense neural network with Relu after hidden layers.

Parameters:

Name Type Description Default
out_dim int

The output dimension.

required
hidden_layers int

The number of hidden layers

required
hidden_dim int

The dimension of the hidden layers

required

Returns:

Name Type Description
init_fun function

The function that initializes the network. Note that this is the init_function defined in the Jax stax module, which is different from the functions of my InitFunction class.

forward_fun function

The function that passes the inputs through the neural network.

Source code in pzflow/utils.py
def DenseReluNetwork(
    out_dim: int, hidden_layers: int, hidden_dim: int
) -> Tuple[Callable, Callable]:
    """Create a dense neural network with Relu after hidden layers.

    Parameters
    ----------
    out_dim : int
        The output dimension.
    hidden_layers : int
        The number of hidden layers
    hidden_dim : int
        The dimension of the hidden layers

    Returns
    -------
    init_fun : function
        The function that initializes the network. Note that this is the
        init_function defined in the Jax stax module, which is different
        from the functions of my InitFunction class.
    forward_fun : function
        The function that passes the inputs through the neural network.
    """
    init_fun, forward_fun = serial(
        *(Dense(hidden_dim), LeakyRelu) * hidden_layers,
        Dense(out_dim),
    )
    return init_fun, forward_fun
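
A brief usage sketch, assuming the stax convention of (init_fun, apply_fun) pairs; the shapes below are arbitrary:

import jax
import jax.numpy as jnp
from pzflow.utils import DenseReluNetwork

# network mapping 3 inputs -> 2 outputs, with two hidden layers of width 16
init_fun, forward_fun = DenseReluNetwork(out_dim=2, hidden_layers=2, hidden_dim=16)

# stax-style initialization: returns the output shape and the network parameters
out_shape, params = init_fun(jax.random.PRNGKey(0), (-1, 3))

x = jnp.ones((5, 3))
y = forward_fun(params, x)   # shape (5, 2)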

RationalQuadraticSpline(inputs, W, H, D, B, periodic=False, inverse=False)

Apply rational quadratic spline to inputs and return outputs with log_det.

Applies the piecewise rational quadratic spline developed in [1].

Parameters:

Name Type Description Default
inputs jnp.ndarray

The inputs to be transformed.

required
W jnp.ndarray

The widths of the spline bins.

required
H jnp.ndarray

The heights of the spline bins.

required
D jnp.ndarray

The derivatives of the inner spline knots.

required
B float

Range of the splines. Outside of (-B,B), the transformation is just the identity.

required
inverse bool

If True, perform the inverse transformation. Otherwise perform the forward transformation.

False
periodic bool

Whether to make this a periodic, Circular Spline [2].

False

Returns:

Name Type Description
outputs jnp.ndarray

The result of applying the splines to the inputs.

log_det jnp.ndarray

The log determinant of the Jacobian at the inputs.

References

[1] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios. Neural Spline Flows. arXiv:1906.04032, 2019. https://arxiv.org/abs/1906.04032
[2] Rezende, Danilo Jimenez et al. Normalizing Flows on Tori and Spheres. arXiv:2002.02428, 2020. http://arxiv.org/abs/2002.02428

Source code in pzflow/utils.py
def RationalQuadraticSpline(
    inputs: jnp.ndarray,
    W: jnp.ndarray,
    H: jnp.ndarray,
    D: jnp.ndarray,
    B: float,
    periodic: bool = False,
    inverse: bool = False,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply rational quadratic spline to inputs and return outputs with log_det.

    Applies the piecewise rational quadratic spline developed in [1].

    Parameters
    ----------
    inputs : jnp.ndarray
        The inputs to be transformed.
    W : jnp.ndarray
        The widths of the spline bins.
    H : jnp.ndarray
        The heights of the spline bins.
    D : jnp.ndarray
        The derivatives of the inner spline knots.
    B : float
        Range of the splines.
        Outside of (-B,B), the transformation is just the identity.
    inverse : bool; default=False
        If True, perform the inverse transformation.
        Otherwise perform the forward transformation.
    periodic : bool; default=False
        Whether to make this a periodic, Circular Spline [2].

    Returns
    -------
    outputs : jnp.ndarray
        The result of applying the splines to the inputs.
    log_det : jnp.ndarray
        The log determinant of the Jacobian at the inputs.

    References
    ----------
    [1] Conor Durkan, Artur Bekasov, Iain Murray, George Papamakarios.
        Neural Spline Flows. arXiv:1906.04032, 2019.
        https://arxiv.org/abs/1906.04032
    [2] Rezende, Danilo Jimenez et al.
        Normalizing Flows on Tori and Spheres. arxiv:2002.02428, 2020
        http://arxiv.org/abs/2002.02428
    """
    # knot x-positions
    xk = jnp.pad(
        -B + jnp.cumsum(W, axis=-1),
        [(0, 0)] * (len(W.shape) - 1) + [(1, 0)],
        mode="constant",
        constant_values=-B,
    )
    # knot y-positions
    yk = jnp.pad(
        -B + jnp.cumsum(H, axis=-1),
        [(0, 0)] * (len(H.shape) - 1) + [(1, 0)],
        mode="constant",
        constant_values=-B,
    )
    # knot derivatives
    if periodic:
        dk = jnp.pad(D, [(0, 0)] * (len(D.shape) - 1) + [(1, 0)], mode="wrap")
    else:
        dk = jnp.pad(
            D,
            [(0, 0)] * (len(D.shape) - 1) + [(1, 1)],
            mode="constant",
            constant_values=1,
        )
    # knot slopes
    sk = H / W

    # if not periodic, out-of-bounds inputs will have identity applied
    # if periodic, we map the input into the appropriate region inside
    # the period. For now, we will pretend all inputs are periodic.
    # This makes sure that out-of-bounds inputs don't cause problems
    # with the spline, but for the non-periodic case, we will replace
    # these with their original values at the end
    out_of_bounds = (inputs <= -B) | (inputs >= B)
    masked_inputs = jnp.where(out_of_bounds, jnp.abs(inputs) - B, inputs)

    # find bin for each input
    if inverse:
        idx = jnp.sum(yk <= masked_inputs[..., None], axis=-1)[..., None] - 1
    else:
        idx = jnp.sum(xk <= masked_inputs[..., None], axis=-1)[..., None] - 1

    # get kx, ky, kyp1, kd, kdp1, kw, ks for the bin corresponding to each input
    input_xk = jnp.take_along_axis(xk, idx, -1)[..., 0]
    input_yk = jnp.take_along_axis(yk, idx, -1)[..., 0]
    input_dk = jnp.take_along_axis(dk, idx, -1)[..., 0]
    input_dkp1 = jnp.take_along_axis(dk, idx + 1, -1)[..., 0]
    input_wk = jnp.take_along_axis(W, idx, -1)[..., 0]
    input_hk = jnp.take_along_axis(H, idx, -1)[..., 0]
    input_sk = jnp.take_along_axis(sk, idx, -1)[..., 0]

    if inverse:
        # [1] Appendix A.3
        # quadratic formula coefficients
        a = (input_hk) * (input_sk - input_dk) + (masked_inputs - input_yk) * (
            input_dkp1 + input_dk - 2 * input_sk
        )
        b = (input_hk) * input_dk - (masked_inputs - input_yk) * (
            input_dkp1 + input_dk - 2 * input_sk
        )
        c = -input_sk * (masked_inputs - input_yk)

        relx = 2 * c / (-b - jnp.sqrt(b**2 - 4 * a * c))
        outputs = relx * input_wk + input_xk
        # if not periodic, replace out-of-bounds values with original values
        if not periodic:
            outputs = jnp.where(out_of_bounds, inputs, outputs)

        # [1] Appendix A.2
        # calculate the log determinant
        dnum = (
            input_dkp1 * relx**2
            + 2 * input_sk * relx * (1 - relx)
            + input_dk * (1 - relx) ** 2
        )
        dden = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
            1 - relx
        )
        log_det = 2 * jnp.log(input_sk) + jnp.log(dnum) - 2 * jnp.log(dden)
        # if not periodic, replace log_det for out-of-bounds values = 0
        if not periodic:
            log_det = jnp.where(out_of_bounds, 0, log_det)
        log_det = log_det.sum(axis=1)

        return outputs, -log_det

    else:
        # [1] Appendix A.1
        # calculate spline
        relx = (masked_inputs - input_xk) / input_wk
        num = input_hk * (input_sk * relx**2 + input_dk * relx * (1 - relx))
        den = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
            1 - relx
        )
        outputs = input_yk + num / den
        # if not periodic, replace out-of-bounds values with original values
        if not periodic:
            outputs = jnp.where(out_of_bounds, inputs, outputs)

        # [1] Appendix A.2
        # calculate the log determinant
        dnum = (
            input_dkp1 * relx**2
            + 2 * input_sk * relx * (1 - relx)
            + input_dk * (1 - relx) ** 2
        )
        dden = input_sk + (input_dkp1 + input_dk - 2 * input_sk) * relx * (
            1 - relx
        )
        log_det = 2 * jnp.log(input_sk) + jnp.log(dnum) - 2 * jnp.log(dden)
        # if not periodic, replace log_det for out-of-bounds values = 0
        if not periodic:
            log_det = jnp.where(out_of_bounds, 0, log_det)
        log_det = log_det.sum(axis=1)

        return outputs, log_det
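
To make the shape conventions concrete, here is a small sketch with uniform, identity-like spline parameters (equal bin widths and heights summing to 2B, and unit knot derivatives); the array sizes are arbitrary:

import jax.numpy as jnp
from pzflow.utils import RationalQuadraticSpline

B, K = 5.0, 8          # spline range (-B, B) and number of bins
n, d = 4, 2            # 4 inputs of dimension 2

W = jnp.full((n, d, K), 2 * B / K)   # bin widths sum to 2B along the last axis
H = jnp.full((n, d, K), 2 * B / K)   # bin heights sum to 2B along the last axis
D = jnp.ones((n, d, K - 1))          # derivatives at the K-1 interior knots

x = jnp.linspace(-4, 4, n * d).reshape(n, d)
y, log_det = RationalQuadraticSpline(x, W, H, D, B)

# the inverse transformation returns minus the forward log determinant
x_back, neg_log_det = RationalQuadraticSpline(y, W, H, D, B, inverse=True)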

build_bijector_from_info(info)

Build a Bijector from a Bijector_Info object

Source code in pzflow/utils.py
def build_bijector_from_info(info: tuple) -> tuple:
    """Build a Bijector from a Bijector_Info object"""

    # recurse through chains
    if info[0] == "Chain":
        return bijectors.Chain(*(build_bijector_from_info(i) for i in info[1]))
    # build individual bijector from name and parameters
    else:
        return getattr(bijectors, info[0])(*info[1])
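
For example, the info tuple returned alongside any bijector can be used to rebuild an equivalent bijector; the Reverse/Scale chain below is just an illustrative choice:

from pzflow import bijectors
from pzflow.utils import build_bijector_from_info

# build a bijector and keep its Bijector_Info tuple...
init_fun, info = bijectors.Chain(bijectors.Reverse(), bijectors.Scale(2.0))

# ...then reconstruct an equivalent bijector purely from that tuple
init_fun_rebuilt, info_rebuilt = build_bijector_from_info(info)
# info_rebuilt matches the original info tuple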

gaussian_error_model(key, X, Xerr, nsamples)

Default Gaussian error model where X are the means and Xerr are the stds.

Source code in pzflow/utils.py
def gaussian_error_model(
    key, X: jnp.ndarray, Xerr: jnp.ndarray, nsamples: int
) -> jnp.ndarray:
    """
    Default Gaussian error model where X are the means and Xerr are the stds.
    """

    eps = random.normal(key, shape=(X.shape[0], nsamples, X.shape[1]))

    return X[:, None, :] + eps * Xerr[:, None, :]
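
A tiny sketch of the expected shapes (the values are made up):

import jax
import jax.numpy as jnp
from pzflow.utils import gaussian_error_model

X = jnp.array([[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]])   # 3 points, 2 variables
Xerr = 0.1 * jnp.ones_like(X)                          # hypothetical standard deviations

# 10 Gaussian samples per point: shape (3, 10, 2)
samples = gaussian_error_model(jax.random.PRNGKey(0), X, Xerr, nsamples=10)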

sub_diag_indices(inputs)

Return indices for diagonal of 2D blocks in 3D array

Source code in pzflow/utils.py
def sub_diag_indices(
    inputs: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Return indices for diagonal of 2D blocks in 3D array"""
    if inputs.ndim != 3:
        raise ValueError("Input must be a 3D array.")
    nblocks = inputs.shape[0]
    ndiag = min(inputs.shape[1], inputs.shape[2])
    idx = (
        jnp.repeat(jnp.arange(nblocks), ndiag),
        jnp.tile(jnp.arange(ndiag), nblocks),
        jnp.tile(jnp.arange(ndiag), nblocks),
    )
    return idx
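
As a quick illustration (array sizes chosen arbitrarily), the returned index tuple can be used to address the diagonal of every block at once:

import jax.numpy as jnp
from pzflow.utils import sub_diag_indices

blocks = jnp.zeros((2, 3, 3))          # a stack of two 3x3 blocks
idx = sub_diag_indices(blocks)

# set each block's diagonal to 1, turning every block into the identity
blocks = blocks.at[idx].set(1.0)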