%matplotlib inline
%reload_ext autoreload
%autoreload 2
from typing import Tuple

import torch
from torch import Tensor, autograd
from fastcore.all import partialler  # assumption: partialler is taken from fastcore here

# interp_shape (nominal/up/down template interpolation) is assumed to be defined in an earlier cell

def parallel_calc_nll(s_true:float, b_true:float, s_exp:Tensor, f_s:Tensor, alpha:Tensor,
                      f_b_nom:Tensor, f_b_up:Tensor, f_b_dw:Tensor) -> Tensor:
    r'''Unused
    Compute multiple negative log-likelihoods for the specified parameters in parallel.
    Unused due to the difficulty of computing batch-wise Hessians in PyTorch.'''
    f_b = interp_shape(alpha, f_b_nom, f_b_up, f_b_dw)  # Background shape after nuisance interpolation, (n_mu, n_bins)
    t_exp = (s_exp[:,None]*f_s[None,:])+(b_true*f_b)    # Expected bin counts for each tested signal strength
    asimov = (s_true*f_s)+(b_true*f_b_nom)              # Asimov (expected) data at the true parameter values
    p = torch.distributions.Poisson(t_exp, False)       # Per-bin Poisson, argument validation disabled
    return -p.log_prob(asimov).sum(1)                   # Summed NLL per tested signal strength
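
# Minimal usage sketch (hypothetical shapes and values, not from the original notebook):
# evaluate the NLL for several tested signal strengths in one call, assuming 10 bins,
# two nuisance parameters, toy templates that each sum to one, and that interp_shape
# accepts the shapes used below. All underscore-prefixed names are illustrative.
_f_s     = torch.softmax(torch.randn(10), dim=-1)   # toy signal template, (n_bins,)
_f_b     = torch.softmax(torch.randn(10), dim=-1)   # toy nominal background template, (n_bins,)
_f_b_up  = torch.stack([_f_b, _f_b])                # toy +1 sigma background shapes, (n_nuisances, n_bins)
_f_b_dw  = torch.stack([_f_b, _f_b])                # toy -1 sigma background shapes, (n_nuisances, n_bins)
_mu_scan = torch.linspace(20, 80, 61)               # tested signal strengths
_alpha   = torch.zeros((len(_mu_scan), 2), requires_grad=True)  # one nuisance vector per tested mu
_nlls = parallel_calc_nll(s_true=50, b_true=1000, s_exp=_mu_scan, f_s=_f_s, alpha=_alpha,
                          f_b_nom=_f_b[None,:], f_b_up=_f_b_up, f_b_dw=_f_b_dw)  # shape (61,)
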
def calc_diag_grad_hesse(nll:Tensor, alpha:Tensor) -> Tuple[Tensor,Tensor]:
    r'''Unused
    Compute the batch-wise gradient and Hessian of the NLL w.r.t. the nuisances, but only the diagonal Hessian elements.'''
    grad = autograd.grad(nll, alpha, torch.ones_like(nll, device=nll.device), create_graph=True)[0]
    hesse = autograd.grad(grad, alpha, torch.ones_like(alpha, device=nll.device), create_graph=True, retain_graph=True)[0]
    alpha.grad = None  # Clear any gradient accumulated on the nuisance tensor
    return grad, hesse
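
# Self-contained sketch of calc_diag_grad_hesse on a toy separable "NLL", where the exact
# answers are known: gradient 2*(alpha-1) and diagonal Hessian 2. (Strictly, the ones-vector
# trick returns Hessian row-sums, which coincide with the diagonal when the cross-terms
# vanish, as in this separable toy case.) The underscore-prefixed names are illustrative.
_alpha_demo = torch.zeros((5, 3), requires_grad=True)  # 5 batch elements, 3 nuisances
_nll_demo   = ((_alpha_demo-1)**2).sum(1)              # per-batch toy NLL, shape (5,)
_g, _h = calc_diag_grad_hesse(_nll_demo, _alpha_demo)  # both (5, 3): _g == -2, _h == 2
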
def calc_diag_profile(f_s:Tensor, f_b_nom:Tensor, f_b_up:Tensor, f_b_dw:Tensor, n:int,
                      mu_scan:Tensor, true_mu:int, n_steps:int=100, lr:float=0.1,
                      verbose:bool=True) -> Tuple[Tensor,Tensor]:
    r'''Unused
    Compute the profile likelihood over a range of mu values, optimising the nuisances using only the diagonal Hessian elements.'''
    alpha = torch.zeros((len(mu_scan),f_b_up.shape[0]), requires_grad=True, device=f_b_nom.device)  # One nuisance vector per tested mu
    f_b_nom = f_b_nom.unsqueeze(0)
    get_nll = partialler(parallel_calc_nll, s_true=true_mu, b_true=n-true_mu, s_exp=mu_scan,
                         f_s=f_s, f_b_nom=f_b_nom, f_b_up=f_b_up, f_b_dw=f_b_dw)
    for i in range(n_steps):  # Newton-optimise the nuisances
        nll = get_nll(alpha=alpha)
        grad, hesse = calc_diag_grad_hesse(nll, alpha)
        step = torch.clamp(lr*grad.detach()/(hesse+1e-7), -100, 100)  # Clamped (damped) Newton step per nuisance
        alpha = alpha-step
    return get_nll(alpha=alpha), alpha
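
# Hedged usage sketch for calc_diag_profile, reusing the toy templates from the first sketch
# above (all shapes and values are illustrative, not from the original notebook): profile the
# nuisances for each tested mu, then form the likelihood scan relative to its minimum.
_nll_scan, _alpha_hat = calc_diag_profile(f_s=_f_s, f_b_nom=_f_b, f_b_up=_f_b_up, f_b_dw=_f_b_dw,
                                          n=1050, mu_scan=_mu_scan, true_mu=50)
_dnll = _nll_scan-_nll_scan.min()  # profiled delta-NLL as a function of the tested mu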