.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/example01.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_example01.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_example01.py:


===========================
Plotting Template Estimator
===========================

An example plot using the BDE estimator.

.. GENERATED FROM PYTHON SOURCE LINES 9-51

.. code-block:: Python


    import numpy as np
    from matplotlib import pyplot as plt
    from optax import adam

    from bde.ml.loss import GaussianNLLLoss
    from bde.ml.models import BDEEstimator, FullyConnectedModule

    # Configure the estimator: a small fully connected network with two
    # output parameters per input, trained with a Gaussian NLL loss.
    est = BDEEstimator(
        model_class=FullyConnectedModule,
        model_kwargs={
            "n_output_params": 2,
            "layer_sizes": [12, 12],
        },  # No hidden layers by default
        n_chains=20,  # 1 by default
        n_samples=20,  # 1 by default
        chain_len=500,  # 1 by default
        warmup=100,  # 1 by default
        optimizer_class=adam,
        optimizer_kwargs={
            "learning_rate": 1e-3,
        },
        loss=GaussianNLLLoss(),
        batch_size=128,  # 1 by default
        epochs=120,  # 1 by default
        metrics=None,
        validation_size=None,
        seed=42,
    )

    # Toy data: 256 points on the line y = x + 0.2.
    x = np.random.rand(256, 1).astype(np.float32)
    # Sort along the sample axis; the default axis=-1 would be a no-op
    # for an array of shape (256, 1).
    x.sort(axis=0)
    y = x + 0.2

    # Fit on the 1-D target and predict on the training inputs.
    est = est.fit(x, y[..., 0])
    y_pred = est.predict(x)

    # Compare the true line with the predictions.
    plt.plot(x, y, "k", alpha=0.4, label="True")
    plt.scatter(x, y_pred, c="crimson", label="Prediction")
    plt.legend()
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()


.. _sphx_glr_download_auto_examples_example01.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: example01.ipynb <example01.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: example01.py <example01.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: example01.zip <example01.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_