Skip to content
Snippets Groups Projects
Commit 8f5426df authored by Hu Zhao's avatar Hu Zhao
Browse files

feat: change emulate_and_predict_ln_pxl to approx_ln_pxl

parent 39498bba
Branches
Tags
No related merge requests found
......@@ -243,13 +243,13 @@ class ActiveLearning:
var_samples : numpy array
Variable input samples of `ninit` simulations and `niter`
iterative simulations. 2d array of shape (ninit+niter, ndim).
sim_outputs : numpy array
Outputs of `ninit` and `niter` simulations, corresponding to `data`.
2d array of shape (ninit+niter, len(data)).
ln_pxl_values : numpy array
Natural logarithm values of the product of prior and likelihood
at `ninit` and `niter` simulations.
1d array of shape (ninit+niter,).
sim_outputs : numpy array
Outputs of `ninit` and `niter` simulations, corresponding to `data`.
2d array of shape (ninit+niter, len(data)).
"""
if init_var_samples.shape != (ninit, self.ndim):
raise ValueError("init_var_samples must be of shape (ninit, ndim)")
......@@ -301,32 +301,28 @@ class ActiveLearning:
next_var_sample.reshape(-1), next_sim_output)
ln_pxl_values.append(next_ln_pxl_value)
# train final scalar gasp
self._emulate_ln_pxl(var_samples, np.array(ln_pxl_values))
ln_pxl_values = np.array(ln_pxl_values)
return var_samples, ln_pxl_values, sim_outputs
return var_samples, sim_outputs, ln_pxl_values
@beartype
def emulate_and_predict_ln_pxl(self, x: np.ndarray, var_samples: np.ndarray,
ln_pxl_values: np.ndarray) -> float:
def approx_ln_pxl(self, x: np.ndarray) -> float:
"""
Build a scalar GP emulator for ln_pxl and make prediction at new x.
Approximate ln_pxl value at x based on the trained calar GP emulator.
Parameters
----------
x : numpy array
One variable sample at which ln_pxl is to be approximated. 1d array
of shape (ndim,)
var_samples : numpy array
Samples of variable inputs. 2d array of shape (n, ndim).
ln_pxl_values : numpy array
Natural logarithm values of the product of prior and likelihood
at `var_samples`. 1d array of shape (n,).
Returns
-------
A float value which is the emulator-predicted ln_pxl value at x.
"""
self._emulate_ln_pxl(var_samples, ln_pxl_values)
predict = self._predict_ln_pxl(x)
return float(predict[:,0])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment