
flowEnsemble

Defines the FlowEnsemble object, which holds an ensemble of normalizing flows.

FlowEnsemble

An ensemble of normalizing flows.

Attributes:

Name Type Description
data_columns tuple

List of DataFrame columns that the flows expect/produce.

conditional_columns tuple

List of DataFrame columns on which the flows are conditioned.

latent distributions.LatentDist

The latent distribution of the normalizing flows. Has its own sample and log_prob methods.

data_error_model Callable

The error model for the data variables. See the docstring of __init__ for more details.

condition_error_model Callable

The error model for the conditional variables. See the docstring of __init__ for more details.

info Any

Object containing any kind of info included with the ensemble. Often describes the data the flows are trained on.

Source code in pzflow/flowEnsemble.py
class FlowEnsemble:
    """An ensemble of normalizing flows.

    Attributes
    ----------
    data_columns : tuple
        List of DataFrame columns that the flows expect/produce.
    conditional_columns : tuple
        List of DataFrame columns on which the flows are conditioned.
    latent: distributions.LatentDist
        The latent distribution of the normalizing flows.
        Has its own sample and log_prob methods.
    data_error_model : Callable
        The error model for the data variables. See the docstring of
        __init__ for more details.
    condition_error_model : Callable
        The error model for the conditional variables. See the docstring
        of __init__ for more details.
    info : Any
        Object containing any kind of info included with the ensemble.
        Often describes the data the flows are trained on.
    """

    def __init__(
        self,
        data_columns: Sequence[str] = None,
        bijector: Tuple[InitFunction, Bijector_Info] = None,
        latent: distributions.LatentDist = None,
        conditional_columns: Sequence[str] = None,
        data_error_model: Callable = None,
        condition_error_model: Callable = None,
        autoscale_conditions: bool = True,
        N: int = 1,
        info: Any = None,
        file: str = None,
    ) -> None:
        """Instantiate an ensemble of normalizing flows.

        Note that while all of the init parameters are technically optional,
        you must provide either data_columns OR file.
        In addition, if a file is provided, all other parameters must be None.

        Parameters
        ----------
        data_columns : Sequence[str]; optional
            Tuple, list, or other container of column names.
            These are the columns the flows expect/produce in DataFrames.
        bijector : Bijector Call; optional
            A Bijector call that consists of the bijector InitFunction that
            initializes the bijector and the tuple of Bijector Info.
            Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
            If not provided, the bijector can be set later using
            flow.set_bijector, or by calling flow.train, in which case the
            default bijector will be used. The default bijector is
            ShiftBounds -> RollingSplineCoupling, where the range of shift
            bounds is learned from the training data, and the dimensions of
            RollingSplineCoupling is inferred. The default bijector assumes
            that the latent has support [-5, 5] for every dimension.
        latent : distributions.LatentDist; optional
            The latent distribution for the normalizing flow. Can be any of
            the distributions from pzflow.distributions. If not provided,
            a uniform distribution is used with input_dim = len(data_columns),
            and B=5.
        conditional_columns : Sequence[str]; optional
            Names of columns on which to condition the normalizing flows.
        data_error_model : Callable; optional
            A callable that defines the error model for data variables.
            data_error_model must take key, X, Xerr, nsamples as arguments:
                - key is a jax rng key, e.g. jax.random.PRNGKey(0)
                - X is 2D array of data variables, where the order of variables
                    matches the order of the columns in data_columns
                - Xerr is the corresponding 2D array of errors
                - nsamples is number of samples to draw from error distribution
            data_error_model must return an array of samples with the shape
            (X.shape[0], nsamples, X.shape[1]).
            If data_error_model is not provided, Gaussian error model assumed.
        condition_error_model : Callable; optional
            A callable that defines the error model for conditional variables.
            condition_error_model must take key, X, Xerr, nsamples, where:
                - key is a jax rng key, e.g. jax.random.PRNGKey(0)
                - X is 2D array of conditional variables, where the order of
                    variables matches order of columns in conditional_columns
                - Xerr is the corresponding 2D array of errors
                - nsamples is number of samples to draw from error distribution
            condition_error_model must return array of samples with shape
            (X.shape[0], nsamples, X.shape[1]).
            If condition_error_model is not provided, Gaussian error model
            assumed.
        autoscale_conditions : bool; default=True
            Sets whether or not conditions are automatically standard scaled
            when passed to a conditional flow. I recommend you leave as True.
        N : int; default=1
            The number of flows in the ensemble.
        info : Any; optional
            An object to attach to the info attribute.
        file : str; optional
            Path to file from which to load a pretrained flow ensemble.
            If a file is provided, all other parameters must be None.
        """

        # validate parameters
        if data_columns is None and file is None:
            raise ValueError("You must provide data_columns OR file.")
        if file is not None and any(
            (
                data_columns is not None,
                bijector is not None,
                conditional_columns is not None,
                latent is not None,
                data_error_model is not None,
                condition_error_model is not None,
                info is not None,
            )
        ):
            raise ValueError(
                "If providing a file, please do not provide any other parameters."
            )

        # if file is provided, load everything from the file
        if file is not None:
            # load the file
            with open(file, "rb") as handle:
                save_dict = pickle.load(handle)

            # make sure the saved file is for this class
            c = save_dict.pop("class")
            if c != self.__class__.__name__:
                raise TypeError(
                    f"This save file isn't a {self.__class__.__name__}. It is a {c}."
                )

            # load the ensemble from the dictionary
            self._ensemble = {
                name: Flow(_dictionary=flow_dict)
                for name, flow_dict in save_dict["ensemble"].items()
            }
            # load the metadata
            self.data_columns = save_dict["data_columns"]
            self.conditional_columns = save_dict["conditional_columns"]
            self.data_error_model = save_dict["data_error_model"]
            self.condition_error_model = save_dict["condition_error_model"]
            self.info = save_dict["info"]

            self._latent_info = save_dict["latent_info"]
            self.latent = getattr(distributions, self._latent_info[0])(
                *self._latent_info[1]
            )

        # otherwise create a new ensemble from the provided parameters
        else:
            # save the ensemble of flows
            self._ensemble = {
                f"Flow {i}": Flow(
                    data_columns=data_columns,
                    bijector=bijector,
                    conditional_columns=conditional_columns,
                    latent=latent,
                    data_error_model=data_error_model,
                    condition_error_model=condition_error_model,
                    autoscale_conditions=autoscale_conditions,
                    seed=i,
                    info=f"Flow {i}",
                )
                for i in range(N)
            }
            # save the metadata
            self.data_columns = data_columns
            self.conditional_columns = conditional_columns
            self.latent = self._ensemble["Flow 0"].latent
            self.data_error_model = data_error_model
            self.condition_error_model = condition_error_model
            self.info = info

    def log_prob(
        self,
        inputs: pd.DataFrame,
        err_samples: int = None,
        seed: int = None,
        returnEnsemble: bool = False,
    ) -> jnp.ndarray:
        """Calculates log probability density of inputs.

        Parameters
        ----------
        inputs : pd.DataFrame
            Input data for which log probability density is calculated.
            Every column in self.data_columns must be present.
            If self.conditional_columns is not None, those must be present
            as well. If other columns are present, they are ignored.
        err_samples : int; default=None
            Number of samples from the error distribution to average over for
            the log_prob calculation. If provided, Gaussian errors are assumed,
            and method will look for error columns in `inputs`. Error columns
            must end in `_err`. E.g. the error column for the variable `u` must
            be `u_err`. Zero error assumed for any missing error columns.
        seed : int; default=None
            Random seed for drawing the samples with Gaussian errors.
        returnEnsemble : bool; default=False
            If True, returns log_prob for each flow in the ensemble as an
            array of shape (inputs.shape[0], N flows in ensemble).
            If False, the prob is averaged over the flows in the ensemble,
            and the log of this average is returned as an array of shape
            (inputs.shape[0],)

        Returns
        -------
        jnp.ndarray
            For shape, see returnEnsemble description above.
        """

        # calculate log_prob for each flow in the ensemble
        ensemble = jnp.array(
            [
                flow.log_prob(inputs, err_samples, seed)
                for flow in self._ensemble.values()
            ]
        )

        # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
        ensemble = jnp.rollaxis(ensemble, axis=1)

        if returnEnsemble:
            # return the ensemble of log_probs
            return ensemble
        else:
            # return mean over ensemble
            # note we return log(mean prob) instead of just mean log_prob
            return jnp.log(jnp.exp(ensemble).mean(axis=1))

    def posterior(
        self,
        inputs: pd.DataFrame,
        column: str,
        grid: jnp.ndarray,
        marg_rules: dict = None,
        normalize: bool = True,
        err_samples: int = None,
        seed: int = None,
        batch_size: int = None,
        returnEnsemble: bool = False,
        nan_to_zero: bool = True,
    ) -> jnp.ndarray:
        """Calculates posterior distributions for the provided column.

        Calculates the conditional posterior distribution, assuming the
        data values in the other columns of the DataFrame.

        Parameters
        ----------
        inputs : pd.DataFrame
            Data on which the posterior distributions are conditioned.
            Must have columns matching self.data_columns, *except*
            for the column specified for the posterior (see below).
        column : str
            Name of the column for which the posterior distribution
            is calculated. Must be one of the columns in self.data_columns.
            However, whether or not this column is one of the columns in
            `inputs` is irrelevant.
        grid : jnp.ndarray
            Grid on which to calculate the posterior.
        marg_rules : dict; optional
            Dictionary with rules for marginalizing over missing variables.
            The dictionary must contain the key "flag", which gives the flag
            that indicates a missing value. E.g. if missing values are given
            the value 99, the dictionary should contain {"flag": 99}.
            The dictionary must also contain {"name": callable} for any
            variables that will need to be marginalized over, where name is
            the name of the variable, and callable is a callable that takes
            the row of variables and returns a grid over which to marginalize
            the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
            Note: the callable for a given name must *always* return an array
            of the same length, regardless of the input row.
        normalize : boolean; default=True
            Whether to normalize the posterior so that it integrates to 1.
        err_samples : int; default=None
            Number of samples from the error distribution to average over for
            the posterior calculation. If provided, Gaussian errors are assumed,
            and method will look for error columns in `inputs`. Error columns
            must end in `_err`. E.g. the error column for the variable `u` must
            be `u_err`. Zero error assumed for any missing error columns.
        seed : int; default=None
            Random seed for drawing the samples with Gaussian errors.
        batch_size : int; default=None
            Size of batches in which to calculate posteriors. If None, all
            posteriors are calculated simultaneously. Simultaneous calculation
            is faster, but memory intensive for large data sets.
        returnEnsemble : bool; default=False
            If True, returns posterior for each flow in the ensemble as an
            array of shape (inputs.shape[0], N flows in ensemble, grid.size).
            If False, the posterior is averaged over the flows in the ensemble,
            and returned as an array of shape (inputs.shape[0], grid.size)
        nan_to_zero : bool; default=True
            Whether to convert NaNs to zero probability in the final pdfs.

        Returns
        -------
        jnp.ndarray
            For shape, see returnEnsemble description above.
        """

        # calculate posterior for each flow in the ensemble
        ensemble = jnp.array(
            [
                flow.posterior(
                    inputs=inputs,
                    column=column,
                    grid=grid,
                    marg_rules=marg_rules,
                    err_samples=err_samples,
                    seed=seed,
                    batch_size=batch_size,
                    normalize=False,
                    nan_to_zero=nan_to_zero,
                )
                for flow in self._ensemble.values()
            ]
        )

        # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
        ensemble = jnp.rollaxis(ensemble, axis=1)

        if returnEnsemble:
            # return the ensemble of posteriors
            if normalize:
                ensemble = ensemble.reshape(-1, grid.size)
                ensemble = ensemble / jnp.trapz(y=ensemble, x=grid).reshape(
                    -1, 1
                )
                ensemble = ensemble.reshape(inputs.shape[0], -1, grid.size)
            return ensemble
        else:
            # return mean over ensemble
            pdfs = ensemble.mean(axis=1)
            if normalize:
                pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
            return pdfs

    def sample(
        self,
        nsamples: int = 1,
        conditions: pd.DataFrame = None,
        save_conditions: bool = True,
        seed: int = None,
        returnEnsemble: bool = False,
    ) -> pd.DataFrame:
        """Returns samples from the ensemble.

        Parameters
        ----------
        nsamples : int; default=1
            The number of samples to be returned, either overall or per flow
            in the ensemble (see returnEnsemble below).
        conditions : pd.DataFrame; optional
            If this is a conditional flow, you must pass conditions for
            each sample. nsamples will be drawn for each row in conditions.
        save_conditions : bool; default=True
            If true, conditions will be saved in the DataFrame of samples
            that is returned.
        seed : int; optional
            Sets the random seed for the samples.
        returnEnsemble : bool; default=False
            If True, nsamples is drawn from each flow in the ensemble.
            If False, nsamples are drawn uniformly from the flows in the ensemble.

        Returns
        -------
        pd.DataFrame
            Pandas DataFrame of samples.
        """

        if returnEnsemble:
            # return nsamples for each flow in the ensemble
            return pd.concat(
                [
                    flow.sample(nsamples, conditions, save_conditions, seed)
                    for flow in self._ensemble.values()
                ],
                keys=self._ensemble.keys(),
            )
        else:
            # if this isn't a conditional flow, sampling is straightforward
            if conditions is None:
                # return nsamples drawn uniformly from the flows in the ensemble
                N = int(jnp.ceil(nsamples / len(self._ensemble)))
                samples = pd.concat(
                    [
                        flow.sample(N, conditions, save_conditions, seed)
                        for flow in self._ensemble.values()
                    ]
                )
                return samples.sample(nsamples, random_state=seed).reset_index(
                    drop=True
                )
            # if this is a conditional flow, it's a little more complicated...
            else:
                # if nsamples > 1, we duplicate the rows of the conditions
                if nsamples > 1:
                    conditions = pd.concat([conditions] * nsamples)

                # now the main sampling algorithm
                seed = np.random.randint(1e18) if seed is None else seed
                # if we are drawing more samples than the number of flows in
                # the ensemble, then we will shuffle the conditions and randomly
                # assign them to one of the constituent flows
                if conditions.shape[0] > len(self._ensemble):
                    # shuffle the conditions
                    conditions_shuffled = conditions.sample(
                        frac=1.0, random_state=int(seed / 1e9)
                    )
                    # split conditions into ~equal sized chunks
                    chunks = np.array_split(
                        conditions_shuffled, len(self._ensemble)
                    )
                    # shuffle the chunks
                    chunks = [
                        chunks[i]
                        for i in random.permutation(
                            random.PRNGKey(seed), jnp.arange(len(chunks))
                        )
                    ]
                    # sample from each flow, and return all the samples
                    return pd.concat(
                        [
                            flow.sample(
                                1, chunk, save_conditions, seed
                            ).set_index(chunk.index)
                            for flow, chunk in zip(
                                self._ensemble.values(), chunks
                            )
                        ]
                    ).sort_index()
                # however, if there are more flows in the ensemble than samples
                # being drawn, then we will randomly select flows for each condition
                else:
                    rng = np.random.default_rng(seed)
                    # randomly select a flow to sample from for each condition
                    flows = rng.choice(
                        list(self._ensemble.values()),
                        size=conditions.shape[0],
                        replace=True,
                    )
                    # sample from each flow and return all the samples together
                    seeds = rng.integers(1e18, size=conditions.shape[0])
                    return pd.concat(
                        [
                            flow.sample(
                                1,
                                conditions[i : i + 1],
                                save_conditions,
                                new_seed,
                            )
                            for i, (flow, new_seed) in enumerate(
                                zip(flows, seeds)
                            )
                        ],
                    ).set_index(conditions.index)

    def save(self, file: str) -> None:
        """Saves the ensemble to a file.

        Pickles the ensemble and saves it to a file that can be passed as
        the `file` argument during flow instantiation.

        WARNING: Currently, this method only works for bijectors that are
        implemented in the `bijectors` module. If you want to save a flow
        with a custom bijector, you either need to add the bijector to that
        module, or handle the saving and loading on your end.

        Parameters
        ----------
        file : str
            Path to where the ensemble will be saved.
            The path is used as given; this method does not append an extension.
        """
        save_dict = {
            "data_columns": self.data_columns,
            "conditional_columns": self.conditional_columns,
            "latent_info": self.latent.info,
            "data_error_model": self.data_error_model,
            "condition_error_model": self.condition_error_model,
            "info": self.info,
            "class": self.__class__.__name__,
            "ensemble": {
                name: flow._save_dict()
                for name, flow in self._ensemble.items()
            },
        }

        with open(file, "wb") as handle:
            pickle.dump(save_dict, handle, recurse=True)

    def train(
        self,
        inputs: pd.DataFrame,
        val_set: pd.DataFrame = None,
        epochs: int = 50,
        batch_size: int = 1024,
        optimizer: Callable = None,
        loss_fn: Callable = None,
        convolve_errs: bool = False,
        patience: int = None,
        best_params: bool = True,
        seed: int = 0,
        verbose: bool = False,
        progress_bar: bool = False,
    ) -> dict:
        """Trains the normalizing flows on the provided inputs.

        Parameters
        ----------
        inputs : pd.DataFrame
            Data on which to train the normalizing flows.
            Must have columns matching self.data_columns.
        val_set : pd.DataFrame; default=None
            Validation set, of same format as inputs. If provided,
            validation loss will be calculated at the end of each epoch.
        epochs : int; default=50
            Number of epochs to train.
        batch_size : int; default=1024
            Batch size for training.
        optimizer : optax optimizer
            An optimizer from Optax. default = optax.adam(learning_rate=1e-3)
            see https://optax.readthedocs.io/en/latest/index.html for more.
        loss_fn : Callable; optional
            A function to calculate the loss: loss = loss_fn(params, x).
            If not provided, will be -mean(log_prob).
        convolve_errs : bool; default=False
            Whether to draw new data from the error distributions during
            each epoch of training. Method will look for error columns in
            `inputs`. Error columns must end in `_err`. E.g. the error column
            for the variable `u` must be `u_err`. Zero error assumed for
            any missing error columns. The error distribution is set during
            ensemble instantiation.
        patience : int; optional
            Factor that controls early stopping. Training will stop if the
            loss doesn't decrease for this number of epochs.
        best_params : bool; default=True
            Whether to use the params from the epoch with the lowest loss.
            Note if a validation set is provided, the epoch with the lowest
            validation loss is chosen. If False, the params from the final
            epoch are saved.
        seed : int; default=0
            A random seed to control the batching and the (optional)
            error sampling.
        verbose : bool; default=False
            If true, print the training loss every 5% of epochs.
        progress_bar : bool; default=False
            If true, display a tqdm progress bar during training.

        Returns
        -------
        dict
            Dictionary of training losses from every epoch for each flow
            in the ensemble.
        """

        # generate random seeds for each flow
        rng = np.random.default_rng(seed)
        seeds = rng.integers(1e9, size=len(self._ensemble))

        loss_dict = dict()

        for i, (name, flow) in enumerate(self._ensemble.items()):
            if verbose:
                print(name)

            loss_dict[name] = flow.train(
                inputs=inputs,
                val_set=val_set,
                epochs=epochs,
                batch_size=batch_size,
                optimizer=optimizer,
                loss_fn=loss_fn,
                convolve_errs=convolve_errs,
                patience=patience,
                best_params=best_params,
                seed=seeds[i],
                verbose=verbose,
                progress_bar=progress_bar,
            )

        return loss_dict
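
The `log_prob` method above combines the flows by averaging probabilities and then taking the log, rather than averaging the log-probabilities. The sketch below reproduces that reduction in plain NumPy on made-up numbers. The second, log-sum-exp style form is an assumption offered for comparison: it is mathematically equivalent but is not what the source uses, and it stays finite when the probabilities underflow.

```python
import numpy as np

# Hypothetical per-flow log-probabilities, shape (n_inputs, n_flows).
ensemble = np.array(
    [
        [-1.0, -2.0, -3.0],
        [-800.0, -801.0, -802.0],  # exp() of these underflows to 0 in float64
    ]
)

# Direct form, as in the source: log(mean(exp(log_prob))).
with np.errstate(divide="ignore"):
    direct = np.log(np.exp(ensemble).mean(axis=1))

# Equivalent form: subtract each row's maximum before exponentiating,
# then add it back after taking the log.
row_max = ensemble.max(axis=1, keepdims=True)
stable = row_max[:, 0] + np.log(np.exp(ensemble - row_max).mean(axis=1))
```

For well-scaled inputs the two forms agree; for the second row the direct form returns `-inf` while the shifted form does not. Which behavior is preferable depends on how far into the tails the ensemble is evaluated.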

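The `posterior` method above asks each flow for an unnormalized posterior, then either normalizes every flow's curve separately (`returnEnsemble=True`) or averages over the flows first and normalizes the mean. The following is a minimal NumPy sketch of both paths. The random input array and the `trapezoid` helper are hypothetical stand-ins: the source calls `jnp.trapz` on real flow outputs.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal-rule integral of y over x along the last axis."""
    return np.sum(0.5 * (y[..., 1:] + y[..., :-1]) * np.diff(x), axis=-1)


grid = np.linspace(0.0, 1.0, 101)

# Hypothetical unnormalized posteriors, shape (n_inputs, n_flows, grid.size).
rng = np.random.default_rng(0)
ensemble = rng.random((4, 3, grid.size))

# returnEnsemble=False: average over the flows, then normalize each row.
pdfs = ensemble.mean(axis=1)
pdfs = pdfs / trapezoid(pdfs, grid).reshape(-1, 1)

# returnEnsemble=True: normalize each flow's posterior on its own.
per_flow = ensemble / trapezoid(ensemble, grid)[..., None]
```

Note that averaging unnormalized curves and then normalizing is not the same as averaging already-normalized curves, which is why the per-flow calls in the source pass `normalize=False`.
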
__init__(data_columns=None, bijector=None, latent=None, conditional_columns=None, data_error_model=None, condition_error_model=None, autoscale_conditions=True, N=1, info=None, file=None)

Instantiate an ensemble of normalizing flows.

Note that while all of the init parameters are technically optional, you must provide either data_columns OR file. In addition, if a file is provided, all other parameters must be None.

Parameters:

Name Type Description Default
data_columns Sequence[str]

Tuple, list, or other container of column names. These are the columns the flows expect/produce in DataFrames.

None
bijector Bijector Call

A Bijector call that consists of the bijector InitFunction that initializes the bijector and the tuple of Bijector Info. Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc. If not provided, the bijector can be set later using flow.set_bijector, or by calling flow.train, in which case the default bijector will be used. The default bijector is ShiftBounds -> RollingSplineCoupling, where the range of shift bounds is learned from the training data, and the dimensions of RollingSplineCoupling is inferred. The default bijector assumes that the latent has support [-5, 5] for every dimension.

None
latent distributions.LatentDist

The latent distribution for the normalizing flow. Can be any of the distributions from pzflow.distributions. If not provided, a uniform distribution is used with input_dim = len(data_columns), and B=5.

None
conditional_columns Sequence[str]

Names of columns on which to condition the normalizing flows.

None
data_error_model Callable

A callable that defines the error model for data variables. data_error_model must take key, X, Xerr, nsamples as arguments:
- key is a jax rng key, e.g. jax.random.PRNGKey(0)
- X is 2D array of data variables, where the order of variables matches the order of the columns in data_columns
- Xerr is the corresponding 2D array of errors
- nsamples is number of samples to draw from error distribution

data_error_model must return an array of samples with the shape (X.shape[0], nsamples, X.shape[1]). If data_error_model is not provided, Gaussian error model assumed.

None
condition_error_model Callable

A callable that defines the error model for conditional variables. condition_error_model must take key, X, Xerr, nsamples, where:
- key is a jax rng key, e.g. jax.random.PRNGKey(0)
- X is 2D array of conditional variables, where the order of variables matches order of columns in conditional_columns
- Xerr is the corresponding 2D array of errors
- nsamples is number of samples to draw from error distribution

condition_error_model must return array of samples with shape (X.shape[0], nsamples, X.shape[1]). If condition_error_model is not provided, Gaussian error model assumed.

None
autoscale_conditions bool

Sets whether or not conditions are automatically standard scaled when passed to a conditional flow. I recommend you leave as True.

True
N int

The number of flows in the ensemble.

1
info Any

An object to attach to the info attribute.

None
file str

Path to file from which to load a pretrained flow ensemble. If a file is provided, all other parameters must be None.

None
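
To make the error-model contract concrete, here is a minimal sketch of a callable with the signature and return shape described for data_error_model and condition_error_model. The function name `gaussian_error_model` and the toy arrays are hypothetical, and this is an illustration of the interface under the assumption of independent Gaussian errors, not pzflow's built-in implementation.

```python
import jax
import jax.numpy as jnp


def gaussian_error_model(key, X, Xerr, nsamples):
    """Draw nsamples Gaussian realizations of every row of X.

    Returns an array of shape (X.shape[0], nsamples, X.shape[1]).
    """
    # One standard-normal draw per (row, sample, variable).
    eps = jax.random.normal(key, shape=(X.shape[0], nsamples, X.shape[1]))
    # Broadcast the values and their errors across the sample axis.
    return X[:, None, :] + eps * Xerr[:, None, :]


key = jax.random.PRNGKey(0)
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 rows, 2 variables
Xerr = jnp.full_like(X, 0.1)
samples = gaussian_error_model(key, X, Xerr, nsamples=5)
```

A callable of this form could then be passed as data_error_model or condition_error_model; any other distribution can be substituted as long as the argument order and output shape are preserved.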
Source code in pzflow/flowEnsemble.py
def __init__(
    self,
    data_columns: Sequence[str] = None,
    bijector: Tuple[InitFunction, Bijector_Info] = None,
    latent: distributions.LatentDist = None,
    conditional_columns: Sequence[str] = None,
    data_error_model: Callable = None,
    condition_error_model: Callable = None,
    autoscale_conditions: bool = True,
    N: int = 1,
    info: Any = None,
    file: str = None,
) -> None:
    """Instantiate an ensemble of normalizing flows.

    Note that while all of the init parameters are technically optional,
    you must provide either data_columns and bijector OR file.
    In addition, if a file is provided, all other parameters must be None.

    Parameters
    ----------
    data_columns : Sequence[str]; optional
        Tuple, list, or other container of column names.
        These are the columns the flows expect/produce in DataFrames.
    bijector : Bijector Call; optional
        A Bijector call that consists of the bijector InitFunction that
        initializes the bijector and the tuple of Bijector Info.
        Can be the output of any Bijector, e.g. Reverse(), Chain(...), etc.
        If not provided, the bijector can be set later using
        flow.set_bijector, or by calling flow.train, in which case the
        default bijector will be used. The default bijector is
        ShiftBounds -> RollingSplineCoupling, where the range of shift
        bounds is learned from the training data, and the dimensions of
        RollingSplineCoupling is inferred. The default bijector assumes
        that the latent has support [-5, 5] for every dimension.
    latent : distributions.LatentDist; optional
        The latent distribution for the normalizing flow. Can be any of
        the distributions from pzflow.distributions. If not provided,
        a uniform distribution is used with input_dim = len(data_columns),
        and B=5.
    conditional_columns : Sequence[str]; optional
        Names of columns on which to condition the normalizing flows.
    data_error_model : Callable; optional
        A callable that defines the error model for data variables.
        data_error_model must take key, X, Xerr, nsamples as arguments:
            - key is a jax rng key, e.g. jax.random.PRNGKey(0)
            - X is 2D array of data variables, where the order of variables
                matches the order of the columns in data_columns
            - Xerr is the corresponding 2D array of errors
            - nsamples is number of samples to draw from error distribution
        data_error_model must return an array of samples with the shape
        (X.shape[0], nsamples, X.shape[1]).
        If data_error_model is not provided, Gaussian error model assumed.
    condition_error_model : Callable; optional
        A callable that defines the error model for conditional variables.
        condition_error_model must take key, X, Xerr, nsamples, where:
            - key is a jax rng key, e.g. jax.random.PRNGKey(0)
            - X is 2D array of conditional variables, where the order of
                variables matches order of columns in conditional_columns
            - Xerr is the corresponding 2D array of errors
            - nsamples is number of samples to draw from error distribution
        condition_error_model must return array of samples with shape
        (X.shape[0], nsamples, X.shape[1]).
        If condition_error_model is not provided, Gaussian error model
        assumed.
    autoscale_conditions : bool; default=True
        Sets whether or not conditions are automatically standard scaled
        when passed to a conditional flow. I recommend you leave this as True.
    N : int; default=1
        The number of flows in the ensemble.
    info : Any; optional
        An object to attach to the info attribute.
    file : str; optional
        Path to file from which to load a pretrained flow ensemble.
        If a file is provided, all other parameters must be None.
    """

    # validate parameters
    if data_columns is None and file is None:
        raise ValueError("You must provide data_columns OR file.")
    if file is not None and any(
        (
            data_columns is not None,
            bijector is not None,
            conditional_columns is not None,
            latent is not None,
            data_error_model is not None,
            condition_error_model is not None,
            info is not None,
        )
    ):
        raise ValueError(
            "If providing a file, please do not provide any other parameters."
        )

    # if file is provided, load everything from the file
    if file is not None:
        # load the file
        with open(file, "rb") as handle:
            save_dict = pickle.load(handle)

        # make sure the saved file is for this class
        c = save_dict.pop("class")
        if c != self.__class__.__name__:
            raise TypeError(
                f"This save file isn't a {self.__class__.__name__}. It is a {c}."
            )

        # load the ensemble from the dictionary
        self._ensemble = {
            name: Flow(_dictionary=flow_dict)
            for name, flow_dict in save_dict["ensemble"].items()
        }
        # load the metadata
        self.data_columns = save_dict["data_columns"]
        self.conditional_columns = save_dict["conditional_columns"]
        self.data_error_model = save_dict["data_error_model"]
        self.condition_error_model = save_dict["condition_error_model"]
        self.info = save_dict["info"]

        self._latent_info = save_dict["latent_info"]
        self.latent = getattr(distributions, self._latent_info[0])(
            *self._latent_info[1]
        )

    # otherwise create a new ensemble from the provided parameters
    else:
        # save the ensemble of flows
        self._ensemble = {
            f"Flow {i}": Flow(
                data_columns=data_columns,
                bijector=bijector,
                conditional_columns=conditional_columns,
                latent=latent,
                data_error_model=data_error_model,
                condition_error_model=condition_error_model,
                autoscale_conditions=autoscale_conditions,
                seed=i,
                info=f"Flow {i}",
            )
            for i in range(N)
        }
        # save the metadata
        self.data_columns = data_columns
        self.conditional_columns = conditional_columns
        self.latent = self._ensemble["Flow 0"].latent
        self.data_error_model = data_error_model
        self.condition_error_model = condition_error_model
        self.info = info

log_prob(inputs, err_samples=None, seed=None, returnEnsemble=False)

Calculates log probability density of inputs.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Input data for which log probability density is calculated. Every column in self.data_columns must be present. If self.conditional_columns is not None, those must be present as well. If other columns are present, they are ignored.

required
err_samples int

Number of samples from the error distribution to average over for the log_prob calculation. If provided, Gaussian errors are assumed, and method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns.

None
seed int

Random seed for drawing the samples with Gaussian errors.

None
returnEnsemble bool

If True, returns log_prob for each flow in the ensemble as an array of shape (inputs.shape[0], N flows in ensemble). If False, the probability is averaged over the flows in the ensemble, and the log of this average is returned as an array of shape (inputs.shape[0],).

False

Returns:

Type Description
jnp.ndarray

For shape, see returnEnsemble description above.
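
With returnEnsemble=False, the value returned is the log of the ensemble-averaged probability, which is not the same as the average of the per-flow log-probabilities. The sketch below uses plain Python to show the difference. The helper name log_mean_prob is hypothetical, and the log-sum-exp rearrangement it uses is an alternative formulation chosen here for numerical robustness; the source on this page exponentiates and averages directly.

```python
import math


def log_mean_prob(log_probs):
    """Return log(mean(exp(log_probs))) using the log-sum-exp trick."""
    m = max(log_probs)
    # Subtracting the maximum keeps every exponent <= 0, avoiding underflow to 0
    total = sum(math.exp(lp - m) for lp in log_probs)
    return m + math.log(total / len(log_probs))


# Toy log-probabilities of one input under three flows
flow_log_probs = [-1.0, -2.0, -3.0]

# Direct formula: log of the mean probability
direct = math.log(sum(math.exp(lp) for lp in flow_log_probs) / len(flow_log_probs))
stable = log_mean_prob(flow_log_probs)

# Mean of the log-probabilities, shown for comparison only
mean_of_logs = sum(flow_log_probs) / len(flow_log_probs)

# Very negative inputs stay finite with the rearranged formula
tiny = log_mean_prob([-1000.0, -1001.0])
```

By Jensen's inequality, the log of the mean probability is never smaller than the mean of the log-probabilities, so the two quantities should not be used interchangeably.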

Source code in pzflow/flowEnsemble.py
def log_prob(
    self,
    inputs: pd.DataFrame,
    err_samples: int = None,
    seed: int = None,
    returnEnsemble: bool = False,
) -> jnp.ndarray:
    """Calculates log probability density of inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Input data for which log probability density is calculated.
        Every column in self.data_columns must be present.
        If self.conditional_columns is not None, those must be present
        as well. If other columns are present, they are ignored.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the log_prob calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    returnEnsemble : bool; default=False
        If True, returns log_prob for each flow in the ensemble as an
        array of shape (inputs.shape[0], N flows in ensemble).
        If False, the prob is averaged over the flows in the ensemble,
        and the log of this average is returned as an array of shape
        (inputs.shape[0],)

    Returns
    -------
    jnp.ndarray
        For shape, see returnEnsemble description above.
    """

    # calculate log_prob for each flow in the ensemble
    ensemble = jnp.array(
        [
            flow.log_prob(inputs, err_samples, seed)
            for flow in self._ensemble.values()
        ]
    )

    # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
    ensemble = jnp.rollaxis(ensemble, axis=1)

    if returnEnsemble:
        # return the ensemble of log_probs
        return ensemble
    else:
        # return mean over ensemble
        # note we return log(mean prob) instead of just mean log_prob
        return jnp.log(jnp.exp(ensemble).mean(axis=1))

posterior(inputs, column, grid, marg_rules=None, normalize=True, err_samples=None, seed=None, batch_size=None, returnEnsemble=False, nan_to_zero=True)

Calculates posterior distributions for the provided column.

Calculates the conditional posterior distribution, assuming the data values in the other columns of the DataFrame.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Data on which the posterior distributions are conditioned. Must have columns matching self.data_columns, except for the column specified for the posterior (see below).

required
column str

Name of the column for which the posterior distribution is calculated. Must be one of the columns in self.data_columns. However, whether or not this column is one of the columns in inputs is irrelevant.

required
grid jnp.ndarray

Grid on which to calculate the posterior.

required
marg_rules dict

Dictionary with rules for marginalizing over missing variables. The dictionary must contain the key "flag", which gives the flag that indicates a missing value. E.g. if missing values are given the value 99, the dictionary should contain {"flag": 99}. The dictionary must also contain {"name": callable} for any variables that will need to be marginalized over, where name is the name of the variable, and callable is a callable that takes the row of variables and returns a grid over which to marginalize the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}. Note: the callable for a given name must always return an array of the same length, regardless of the input row.

None
normalize bool

Whether to normalize the posterior so that it integrates to 1.

True
err_samples int

Number of samples from the error distribution to average over for the posterior calculation. If provided, Gaussian errors are assumed, and method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns.

None
seed int

Random seed for drawing the samples with Gaussian errors.

None
batch_size int

Size of batches in which to calculate posteriors. If None, all posteriors are calculated simultaneously. Simultaneous calculation is faster, but memory intensive for large data sets.

None
returnEnsemble bool

If True, returns posterior for each flow in the ensemble as an array of shape (inputs.shape[0], N flows in ensemble, grid.size). If False, the posterior is averaged over the flows in the ensemble, and returned as an array of shape (inputs.shape[0], grid.size)

False
nan_to_zero bool

Whether to convert NaNs to zero probability in the final pdfs.

True
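
The marg_rules dictionary described in the table above can be sketched as follows. This is an illustration only: NumPy stands in for jnp so the snippet runs without JAX, and plain dicts stand in for the rows passed to the callable (the actual row type is an assumption here). It checks the documented requirement that the callable always returns a grid of the same length.

```python
import numpy as np

# 99 flags a missing value; "y" is marginalized over a 10-point grid whose
# upper limit depends on the "x" value in the same row.
marg_rules = {
    "flag": 99,
    "y": lambda row: np.linspace(0, row["x"], 10),
}

# Stand-in rows in which "y" is missing (set to the flag value)
rows = [{"x": 1.0, "y": 99}, {"x": 2.5, "y": 99}]

# Build the marginalization grid for every row whose "y" is flagged
grids = [
    marg_rules["y"](row) for row in rows if row["y"] == marg_rules["flag"]
]

# The grid length must not depend on the row
grid_lengths = {len(g) for g in grids}
```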

Returns:

Type Description
jnp.ndarray

For shape, see returnEnsemble description above.
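
When returnEnsemble=False and normalize=True, the source on this page averages the unnormalized per-flow posteriors over the ensemble axis and then divides each averaged curve by its trapezoid-rule integral over grid. A minimal NumPy sketch of that order of operations follows; the helper normalize_on_grid and the toy Gaussian-shaped curves are hypothetical, and the trapezoid rule is written out by hand rather than calling a library routine.

```python
import numpy as np


def normalize_on_grid(pdfs, grid):
    """Rescale each row of pdfs so its trapezoid-rule integral over grid is 1."""
    widths = np.diff(grid)
    areas = np.sum(0.5 * (pdfs[:, 1:] + pdfs[:, :-1]) * widths, axis=1)
    return pdfs / areas.reshape(-1, 1)


grid = np.linspace(-5.0, 5.0, 201)

# Toy unnormalized posteriors with shape (n_inputs, n_flows, grid.size)
centers = np.array([[-1.0, -0.8, -1.2], [2.0, 2.1, 1.9]])
ensemble = np.exp(-0.5 * (grid[None, None, :] - centers[:, :, None]) ** 2)

# Average over the flows first, then normalize the averaged curves
pdfs = normalize_on_grid(ensemble.mean(axis=1), grid)

# Recompute the integrals to confirm the normalization
areas = np.sum(0.5 * (pdfs[:, 1:] + pdfs[:, :-1]) * np.diff(grid), axis=1)
```

Averaging before normalizing weights each flow by its unnormalized mass; normalizing each flow first and then averaging would in general give a different curve.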

Source code in pzflow/flowEnsemble.py
def posterior(
    self,
    inputs: pd.DataFrame,
    column: str,
    grid: jnp.ndarray,
    marg_rules: dict = None,
    normalize: bool = True,
    err_samples: int = None,
    seed: int = None,
    batch_size: int = None,
    returnEnsemble: bool = False,
    nan_to_zero: bool = True,
) -> jnp.ndarray:
    """Calculates posterior distributions for the provided column.

    Calculates the conditional posterior distribution, assuming the
    data values in the other columns of the DataFrame.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which the posterior distributions are conditioned.
        Must have columns matching self.data_columns, *except*
        for the column specified for the posterior (see below).
    column : str
        Name of the column for which the posterior distribution
        is calculated. Must be one of the columns in self.data_columns.
        However, whether or not this column is one of the columns in
        `inputs` is irrelevant.
    grid : jnp.ndarray
        Grid on which to calculate the posterior.
    marg_rules : dict; optional
        Dictionary with rules for marginalizing over missing variables.
        The dictionary must contain the key "flag", which gives the flag
        that indicates a missing value. E.g. if missing values are given
        the value 99, the dictionary should contain {"flag": 99}.
        The dictionary must also contain {"name": callable} for any
        variables that will need to be marginalized over, where name is
        the name of the variable, and callable is a callable that takes
        the row of variables and returns a grid over which to marginalize
        the variable. E.g. {"y": lambda row: jnp.linspace(0, row["x"], 10)}.
        Note: the callable for a given name must *always* return an array
        of the same length, regardless of the input row.
    normalize : bool; default=True
        Whether to normalize the posterior so that it integrates to 1.
    err_samples : int; default=None
        Number of samples from the error distribution to average over for
        the posterior calculation. If provided, Gaussian errors are assumed,
        and method will look for error columns in `inputs`. Error columns
        must end in `_err`. E.g. the error column for the variable `u` must
        be `u_err`. Zero error assumed for any missing error columns.
    seed : int; default=None
        Random seed for drawing the samples with Gaussian errors.
    batch_size : int; default=None
        Size of batches in which to calculate posteriors. If None, all
        posteriors are calculated simultaneously. Simultaneous calculation
        is faster, but memory intensive for large data sets.
    returnEnsemble : bool; default=False
        If True, returns posterior for each flow in the ensemble as an
        array of shape (inputs.shape[0], N flows in ensemble, grid.size).
        If False, the posterior is averaged over the flows in the ensemble,
        and returned as an array of shape (inputs.shape[0], grid.size)
    nan_to_zero : bool; default=True
        Whether to convert NaNs to zero probability in the final pdfs.

    Returns
    -------
    jnp.ndarray
        For shape, see returnEnsemble description above.
    """

    # calculate posterior for each flow in the ensemble
    ensemble = jnp.array(
        [
            flow.posterior(
                inputs=inputs,
                column=column,
                grid=grid,
                marg_rules=marg_rules,
                err_samples=err_samples,
                seed=seed,
                batch_size=batch_size,
                normalize=False,
                nan_to_zero=nan_to_zero,
            )
            for flow in self._ensemble.values()
        ]
    )

    # re-arrange so that (axis 0, axis 1) = (inputs, flows in ensemble)
    ensemble = jnp.rollaxis(ensemble, axis=1)

    if returnEnsemble:
        # return the ensemble of posteriors
        if normalize:
            ensemble = ensemble.reshape(-1, grid.size)
            ensemble = ensemble / jnp.trapz(y=ensemble, x=grid).reshape(
                -1, 1
            )
            ensemble = ensemble.reshape(inputs.shape[0], -1, grid.size)
        return ensemble
    else:
        # return mean over ensemble
        pdfs = ensemble.mean(axis=1)
        if normalize:
            pdfs = pdfs / jnp.trapz(y=pdfs, x=grid).reshape(-1, 1)
        return pdfs

sample(nsamples=1, conditions=None, save_conditions=True, seed=None, returnEnsemble=False)

Returns samples from the ensemble.

Parameters:

Name Type Description Default
nsamples int

The number of samples to be returned, either overall or per flow in the ensemble (see returnEnsemble below).

1
conditions pd.DataFrame

If this is a conditional flow, you must pass conditions for each sample. nsamples will be drawn for each row in conditions.

None
save_conditions bool

If true, conditions will be saved in the DataFrame of samples that is returned.

True
seed int

Sets the random seed for the samples.

None
returnEnsemble bool

If True, nsamples is drawn from each flow in the ensemble. If False, nsamples are drawn uniformly from the flows in the ensemble.

False

Returns:

Type Description
pd.DataFrame

Pandas DataFrame of samples.
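
For conditional sampling with returnEnsemble=False and more conditions than flows, the source on this page shuffles the conditions, splits them into roughly equal chunks, shuffles the chunks, and hands one chunk to each flow. The sketch below reproduces only that bookkeeping on row indices. It is a simplification under stated assumptions: the function name assign_conditions_to_flows is hypothetical, and a single NumPy generator replaces the separate pandas and JAX random sources used in the source.

```python
import numpy as np


def assign_conditions_to_flows(n_conditions, n_flows, seed=0):
    """Return, for each condition row, the index of the flow that samples it."""
    rng = np.random.default_rng(seed)
    # Shuffle the row indices of the conditions
    shuffled = rng.permutation(n_conditions)
    # Split the shuffled indices into ~equal-sized chunks, one per flow
    chunks = np.array_split(shuffled, n_flows)
    # Shuffle which chunk goes to which flow
    order = rng.permutation(n_flows)
    assignment = np.empty(n_conditions, dtype=int)
    for flow_idx, chunk_idx in enumerate(order):
        assignment[chunks[chunk_idx]] = flow_idx
    return assignment


assignment = assign_conditions_to_flows(n_conditions=10, n_flows=3, seed=42)
counts = np.bincount(assignment, minlength=3)
```

Every condition is assigned to exactly one flow, and the per-flow counts differ by at most one, so each flow contributes about the same number of samples.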

Source code in pzflow/flowEnsemble.py
def sample(
    self,
    nsamples: int = 1,
    conditions: pd.DataFrame = None,
    save_conditions: bool = True,
    seed: int = None,
    returnEnsemble: bool = False,
) -> pd.DataFrame:
    """Returns samples from the ensemble.

    Parameters
    ----------
    nsamples : int; default=1
        The number of samples to be returned, either overall or per flow
        in the ensemble (see returnEnsemble below).
    conditions : pd.DataFrame; optional
        If this is a conditional flow, you must pass conditions for
        each sample. nsamples will be drawn for each row in conditions.
    save_conditions : bool; default=True
        If true, conditions will be saved in the DataFrame of samples
        that is returned.
    seed : int; optional
        Sets the random seed for the samples.
    returnEnsemble : bool; default=False
        If True, nsamples is drawn from each flow in the ensemble.
        If False, nsamples are drawn uniformly from the flows in the ensemble.

    Returns
    -------
    pd.DataFrame
        Pandas DataFrame of samples.
    """

    if returnEnsemble:
        # return nsamples for each flow in the ensemble
        return pd.concat(
            [
                flow.sample(nsamples, conditions, save_conditions, seed)
                for flow in self._ensemble.values()
            ],
            keys=self._ensemble.keys(),
        )
    else:
        # if this isn't a conditional flow, sampling is straightforward
        if conditions is None:
            # return nsamples drawn uniformly from the flows in the ensemble
            N = int(jnp.ceil(nsamples / len(self._ensemble)))
            samples = pd.concat(
                [
                    flow.sample(N, conditions, save_conditions, seed)
                    for flow in self._ensemble.values()
                ]
            )
            return samples.sample(nsamples, random_state=seed).reset_index(
                drop=True
            )
        # if this is a conditional flow, it's a little more complicated...
        else:
            # if nsamples > 1, we duplicate the rows of the conditions
            if nsamples > 1:
                conditions = pd.concat([conditions] * nsamples)

            # now the main sampling algorithm
            seed = np.random.randint(1e18) if seed is None else seed
            # if we are drawing more samples than the number of flows in
            # the ensemble, then we will shuffle the conditions and randomly
            # assign them to one of the constituent flows
            if conditions.shape[0] > len(self._ensemble):
                # shuffle the conditions
                conditions_shuffled = conditions.sample(
                    frac=1.0, random_state=int(seed / 1e9)
                )
                # split conditions into ~equal sized chunks
                chunks = np.array_split(
                    conditions_shuffled, len(self._ensemble)
                )
                # shuffle the chunks
                chunks = [
                    chunks[i]
                    for i in random.permutation(
                        random.PRNGKey(seed), jnp.arange(len(chunks))
                    )
                ]
                # sample from each flow, and return all the samples
                return pd.concat(
                    [
                        flow.sample(
                            1, chunk, save_conditions, seed
                        ).set_index(chunk.index)
                        for flow, chunk in zip(
                            self._ensemble.values(), chunks
                        )
                    ]
                ).sort_index()
            # however, if there are more flows in the ensemble than samples
            # being drawn, then we will randomly select flows for each condition
            else:
                rng = np.random.default_rng(seed)
                # randomly select a flow to sample from for each condition
                flows = rng.choice(
                    list(self._ensemble.values()),
                    size=conditions.shape[0],
                    replace=True,
                )
                # sample from each flow and return all the samples together
                seeds = rng.integers(1e18, size=conditions.shape[0])
                return pd.concat(
                    [
                        flow.sample(
                            1,
                            conditions[i : i + 1],
                            save_conditions,
                            new_seed,
                        )
                        for i, (flow, new_seed) in enumerate(
                            zip(flows, seeds)
                        )
                    ],
                ).set_index(conditions.index)

save(file)

Saves the ensemble to a file.

Pickles the ensemble and saves it to a file that can be passed as the file argument during flow instantiation.

WARNING: Currently, this method only works for bijectors that are implemented in the bijectors module. If you want to save a flow with a custom bijector, you either need to add the bijector to that module, or handle the saving and loading on your end.

Parameters:

Name Type Description Default
file str

Path to where the ensemble will be saved. Extension .pkl will be appended if not already present.

required
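
The save and load logic on this page follows a common pattern: write a dictionary that records the class name, then verify that name when loading. The sketch below shows the pattern with a hypothetical ToyEnsemble class and the standard library's pickle. Note that the recurse=True keyword seen in the source is not part of the standard library's pickle.dump signature, which suggests the module's pickle name refers to a compatible third-party serializer; that is an inference, and the keyword is omitted here.

```python
import os
import pickle
import tempfile


class ToyEnsemble:
    """Hypothetical stand-in that mimics the save/load pattern only."""

    def __init__(self, data_columns=None, file=None):
        if file is not None:
            with open(file, "rb") as handle:
                save_dict = pickle.load(handle)
            # Make sure the saved file was written by this class
            c = save_dict.pop("class")
            if c != self.__class__.__name__:
                raise TypeError(
                    f"This save file isn't a {self.__class__.__name__}. It is a {c}."
                )
            self.data_columns = save_dict["data_columns"]
        else:
            self.data_columns = data_columns

    def save(self, file):
        save_dict = {
            "data_columns": self.data_columns,
            "class": self.__class__.__name__,
        }
        with open(file, "wb") as handle:
            pickle.dump(save_dict, handle)


# Round trip through a temporary file
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ensemble.pkl")
    ToyEnsemble(data_columns=("x", "y")).save(path)
    restored = ToyEnsemble(file=path)
```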
Source code in pzflow/flowEnsemble.py
def save(self, file: str) -> None:
    """Saves the ensemble to a file.

    Pickles the ensemble and saves it to a file that can be passed as
    the `file` argument during flow instantiation.

    WARNING: Currently, this method only works for bijectors that are
    implemented in the `bijectors` module. If you want to save a flow
    with a custom bijector, you either need to add the bijector to that
    module, or handle the saving and loading on your end.

    Parameters
    ----------
    file : str
        Path to where the ensemble will be saved.
        Extension `.pkl` will be appended if not already present.
    """
    save_dict = {
        "data_columns": self.data_columns,
        "conditional_columns": self.conditional_columns,
        "latent_info": self.latent.info,
        "data_error_model": self.data_error_model,
        "condition_error_model": self.condition_error_model,
        "info": self.info,
        "class": self.__class__.__name__,
        "ensemble": {
            name: flow._save_dict()
            for name, flow in self._ensemble.items()
        },
    }

    with open(file, "wb") as handle:
        pickle.dump(save_dict, handle, recurse=True)

train(inputs, val_set=None, epochs=50, batch_size=1024, optimizer=None, loss_fn=None, convolve_errs=False, patience=None, best_params=True, seed=0, verbose=False, progress_bar=False)

Trains the normalizing flows on the provided inputs.

Parameters:

Name Type Description Default
inputs pd.DataFrame

Data on which to train the normalizing flows. Must have columns matching self.data_columns.

required
val_set pd.DataFrame

Validation set, of same format as inputs. If provided, validation loss will be calculated at the end of each epoch.

None
epochs int

Number of epochs to train.

50
batch_size int

Batch size for training.

1024
optimizer optax optimizer

An optimizer from Optax. Defaults to optax.adam(learning_rate=1e-3); see https://optax.readthedocs.io/en/latest/index.html for more.

None
loss_fn Callable

A function to calculate the loss: loss = loss_fn(params, x). If not provided, will be -mean(log_prob).

None
convolve_errs bool

Whether to draw new data from the error distributions during each epoch of training. Method will look for error columns in inputs. Error columns must end in _err. E.g. the error column for the variable u must be u_err. Zero error assumed for any missing error columns. The error distribution is set during ensemble instantiation.

False
patience int

Factor that controls early stopping. Training will stop if the loss doesn't decrease for this number of epochs.

None
best_params bool

Whether to use the params from the epoch with the lowest loss. Note if a validation set is provided, the epoch with the lowest validation loss is chosen. If False, the params from the final epoch are saved.

True
seed int

A random seed to control the batching and the (optional) error sampling.

0
verbose bool

If true, print the training loss every 5% of epochs.

False
progress_bar bool

If true, display a tqdm progress bar during training.

False

Returns:

Type Description
dict

Dictionary of training losses from every epoch for each flow in the ensemble.
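
In the source on this page, train derives one seed per flow from the single seed argument and collects each flow's losses under that flow's name. The sketch below mirrors that structure with NumPy. The function toy_train is a hypothetical stand-in for a single flow's training routine, and the assumption that it returns a list of per-epoch losses is made here for illustration only.

```python
import numpy as np


def toy_train(seed, epochs=5):
    """Hypothetical stand-in for training one flow: returns per-epoch losses."""
    rng = np.random.default_rng(int(seed))
    # Sorted in decreasing order to resemble a falling training loss
    return [float(x) for x in np.sort(rng.random(epochs))[::-1]]


names = [f"Flow {i}" for i in range(3)]

# One master seed determines a separate seed for every flow
seeds = np.random.default_rng(0).integers(10**9, size=len(names))

# Losses are keyed by flow name, as in the returned dictionary
loss_dict = {name: toy_train(s) for name, s in zip(names, seeds)}

# Example use of the result: the final-epoch loss of each flow
final_losses = {name: losses[-1] for name, losses in loss_dict.items()}
```

Because the per-flow seeds come from a seeded generator, repeating the call with the same master seed reproduces the same seeds, while each flow still sees different batching.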

Source code in pzflow/flowEnsemble.py
def train(
    self,
    inputs: pd.DataFrame,
    val_set: pd.DataFrame = None,
    epochs: int = 50,
    batch_size: int = 1024,
    optimizer: Callable = None,
    loss_fn: Callable = None,
    convolve_errs: bool = False,
    patience: int = None,
    best_params: bool = True,
    seed: int = 0,
    verbose: bool = False,
    progress_bar: bool = False,
) -> dict:
    """Trains the normalizing flows on the provided inputs.

    Parameters
    ----------
    inputs : pd.DataFrame
        Data on which to train the normalizing flows.
        Must have columns matching self.data_columns.
    val_set : pd.DataFrame; default=None
        Validation set, of same format as inputs. If provided,
        validation loss will be calculated at the end of each epoch.
    epochs : int; default=50
        Number of epochs to train.
    batch_size : int; default=1024
        Batch size for training.
    optimizer : optax optimizer
        An optimizer from Optax. Defaults to optax.adam(learning_rate=1e-3);
        see https://optax.readthedocs.io/en/latest/index.html for more.
    loss_fn : Callable; optional
        A function to calculate the loss: loss = loss_fn(params, x).
        If not provided, will be -mean(log_prob).
    convolve_errs : bool; default=False
        Whether to draw new data from the error distributions during
        each epoch of training. Method will look for error columns in
        `inputs`. Error columns must end in `_err`. E.g. the error column
        for the variable `u` must be `u_err`. Zero error assumed for
        any missing error columns. The error distribution is set during
        ensemble instantiation.
    patience : int; optional
        Factor that controls early stopping. Training will stop if the
        loss doesn't decrease for this number of epochs.
    best_params : bool; default=True
        Whether to use the params from the epoch with the lowest loss.
        Note if a validation set is provided, the epoch with the lowest
        validation loss is chosen. If False, the params from the final
        epoch are saved.
    seed : int; default=0
        A random seed to control the batching and the (optional)
        error sampling.
    verbose : bool; default=False
        If true, print the training loss every 5% of epochs.
    progress_bar : bool; default=False
        If true, display a tqdm progress bar during training.

    Returns
    -------
    dict
        Dictionary of training losses from every epoch for each flow
        in the ensemble.
    """

    # generate random seeds for each flow
    rng = np.random.default_rng(seed)
    seeds = rng.integers(1e9, size=len(self._ensemble))

    loss_dict = dict()

    for i, (name, flow) in enumerate(self._ensemble.items()):
        if verbose:
            print(name)

        loss_dict[name] = flow.train(
            inputs=inputs,
            val_set=val_set,
            epochs=epochs,
            batch_size=batch_size,
            optimizer=optimizer,
            loss_fn=loss_fn,
            convolve_errs=convolve_errs,
            patience=patience,
            best_params=best_params,
            seed=seeds[i],
            verbose=verbose,
            progress_bar=progress_bar,
        )

    return loss_dict