# Source code for counting

"""
counting.py

Functions for the Back-Propagation Counting method.

This module provides tools to process pixelated silicon detector data, including prior computation,
frame counting using PyTorch optimization, and storage in HDF5 format.
"""
import h5py
import numpy as np

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.patches import Circle

from scipy.optimize import minimize
from scipy.stats import poisson

import torch
import torch.nn.functional as F

# Set the device ('cpu', 'cuda', 'mps').
# The default is the best backend available on this machine, so the module also works
# where CUDA is not present. Assign a different string to the module-level `device`
# to force a specific backend.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# -----------------------------------------------------------------------------------
# Helper methods for counting
# -----------------------------------------------------------------------------------
def compute_conditional_probabilities(lam_grid):
    """
    Compute conditional probabilities P(n >= 2 | n >= 1) for a grid of "lambda" values
    (average electron hit probabilities).

    Parameters
    ----------
    lam_grid : ndarray
        A numpy array of lambda values (Poisson parameter).

    Returns
    -------
    ndarray
        A float64 numpy array, with the same shape as ``lam_grid``, of conditional
        probabilities P(n >= 2 | n >= 1). Entries for which P(n >= 1) is not strictly
        positive (e.g. lambda = 0) are set to 0.

    Notes
    -----
    The tail probabilities are taken from the Poisson survival function,
    ``sf(k) = P(n > k)``, rather than ``1 - cdf(k)``: for small lambda the latter
    subtracts two nearly equal numbers and loses P(n >= 2) entirely.
    The division is only evaluated where the denominator is positive, so no
    divide-by-zero or invalid-value warnings are raised.
    """
    lam_grid = np.asarray(lam_grid, dtype=np.float64)

    # P(n >= 1) = P(n > 0) and P(n >= 2) = P(n > 1)
    p_at_least_1 = np.asarray(poisson.sf(0, lam_grid))
    p_at_least_2 = np.asarray(poisson.sf(1, lam_grid))

    # Masked division: positions with a non-positive (or NaN) denominator keep the 0
    # they were initialized with and are never divided.
    conditional_prob = np.zeros_like(p_at_least_1)
    np.divide(p_at_least_2, p_at_least_1, out=conditional_prob, where=p_at_least_1 > 0)

    return conditional_prob
def compute_prior(frames_file, nframes, baseline, gauss_A):
    """
    Compute the "prior" from a set of frames.

    The prior consists of the average number of electron hits at each pixel in the frame
    and has the same dimensions as a single frame.

    Parameters
    ----------
    frames_file : str
        Path to the HDF5 file containing the 'frames' dataset.
    nframes : int
        Number of frames (taken from the start of the dataset) to use for computing
        the prior. If the dataset holds fewer frames, all available frames are used.
    baseline : float
        Baseline value to subtract from each frame.
    gauss_A : float
        Amplitude of the Gaussian profile for a single electron.

    Returns
    -------
    ndarray
        The computed "prior" frame.

    Raises
    ------
    ValueError
        If the requested selection contains no frames.

    Notes
    -----
    The prior is computed by summing the baseline-subtracted frames, dividing by the
    single-electron amplitude and by the number of frames actually read, and setting
    negative values to zero.
    """
    # Read the requested frames into memory and subtract the baseline
    with h5py.File(frames_file, 'r') as f0:
        data = f0['frames']
        prior_bls = np.array(data[0:nframes, :, :], dtype=np.float32) - baseline

    # The slice is silently truncated when nframes exceeds the dataset length, so
    # normalize by the number of frames actually read rather than by the request.
    n_used = prior_bls.shape[0]
    if n_used == 0:
        raise ValueError(f"No frames available to compute the prior (requested nframes={nframes})")
    print(f"Prior values, {n_used} frames, shape: {prior_bls.shape}")

    # Compute the summed frame
    prior_frame = np.sum(prior_bls, axis=0)
    print("Summed frame dimensions:", prior_frame.shape)

    # Divide by the average electron amplitude
    prior_frame /= gauss_A

    # Normalize by the number of frames to get the final "prior"
    prior_frame /= n_used

    # Eliminate negative values
    prior_frame[prior_frame < 0] = 0.

    return prior_frame
def construct_modeled_frame_pytorch(frame_ct, splash_kernel):
    """
    Build modeled frames by convolving a count grid with the single-electron
    Gaussian "splash" kernel.

    Parameters
    ----------
    frame_ct : torch.Tensor
        The count grid tensor (batch of 2D arrays).
    splash_kernel : torch.Tensor
        The Gaussian kernel for convolution.

    Returns
    -------
    torch.Tensor
        The modeled frame after applying the Gaussian splash.
    """
    # Pad by half the kernel width on every side
    half_width = splash_kernel.shape[-1] // 2

    # conv2d wants a channel axis: insert one at position 1 of the float-cast counts
    counts_with_channel = torch.unsqueeze(frame_ct.float(), 1)

    # Spread every count over its neighborhood according to the kernel
    convolved = F.conv2d(counts_with_channel, splash_kernel, padding=half_width)

    # Drop the channel axis again before handing the result back
    return torch.squeeze(convolved, 1)
def count_frame_pytorch(frame_bls, frame_ct, gauss_A, gauss_sigma, n_steps_max=5000, loss_lim=1,
                        min_loss_patience=10, min_loss_improvement=0.01):
    """
    Count electrons in a frame using PyTorch optimization.

    The count grid is treated as a free parameter and adjusted with Adam so that the
    count grid convolved with the single-electron Gaussian kernel matches the
    baseline-subtracted data in the sum-of-squared-differences sense.

    Parameters
    ----------
    frame_bls : ndarray
        Baseline-subtracted frame data.
    frame_ct : ndarray
        Initial guess for the counted frame.
    gauss_A : float
        Amplitude of the Gaussian profile.
    gauss_sigma : float
        Standard deviation of the Gaussian profile.
    n_steps_max : int, optional
        Maximum number of optimization steps (default: 5000).
    loss_lim : float, optional
        Loss threshold to stop optimization (default: 1).
    min_loss_patience : int, optional
        Number of steps to wait for loss improvement (default: 10).
    min_loss_improvement : float, optional
        Minimum relative loss improvement to continue (default: 0.01).

    Returns
    -------
    tuple
        - ndarray: The counted frame (un-rounded optimizer output).
        - ndarray: The modeled frame, which is the returned counted frame convoluted
          with the Gaussian profile (kernel).
        - ndarray: Loss values at each step.

    Notes
    -----
    The modeled frame is recomputed from the final counts after the loop, so it always
    corresponds to the counts that are returned and is defined even when no
    optimization step is taken (e.g. ``n_steps_max <= 0``).
    """
    # Convert frame and frame_ct to PyTorch tensors; frame_ct is the variable being optimized
    frame_tensor = torch.from_numpy(frame_bls).to(device)
    frame_ct_tensor = torch.tensor(frame_ct, dtype=torch.float32, device=device, requires_grad=True)

    # Define the optimizer
    optimizer = torch.optim.Adam([frame_ct_tensor], lr=0.01)

    # Define the single-electron Gaussian "splash"
    splash = gaussian_splash_pytorch(gauss_A, gauss_sigma)

    # Set up the parameters for the iteration. The loss is tracked as a Python float so the
    # device is synchronized once per step and log messages show a plain number.
    n_steps = 0
    loss_value = loss_lim + 1

    # Record the loss at each step
    loss_steps = []

    # Boolean for stopping due to loss improvement condition
    improvement_stop = False
    while loss_value > loss_lim and n_steps < n_steps_max and not improvement_stop:
        optimizer.zero_grad()  # Clear previous gradients

        # Construct the modeled frame
        modeled_frame = construct_modeled_frame_pytorch(frame_ct_tensor, splash)

        # Compute the loss (sum of squared differences between data and model)
        loss = torch.sum((frame_tensor - modeled_frame) ** 2)

        # Compute gradients
        loss.backward()

        # Update frame_ct_tensor based on gradients
        optimizer.step()

        # Record the loss
        loss_value = loss.item()
        loss_steps.append(loss_value)

        # Check for stopping condition based on improvement.
        # NOTE(review): after the append above, index -min_loss_patience is the loss from
        # min_loss_patience - 1 steps back -- confirm this window is the intended one.
        if n_steps > min_loss_patience:
            reference_loss = loss_steps[-min_loss_patience]
            # A zero reference loss cannot improve further; treat it as "no improvement"
            # instead of dividing by zero.
            if reference_loss != 0:
                relative_improvement = (reference_loss - loss_value) / reference_loss
            else:
                relative_improvement = 0.0
            if relative_improvement < min_loss_improvement:
                print(f"* Stopping at step {n_steps} due to small relative improvement ({relative_improvement:.4f})")
                improvement_stop = True

        n_steps += 1
        if n_steps >= n_steps_max:
            print(f"* Stopping at max steps {n_steps} with loss {loss_value}")
        if loss_value <= loss_lim:
            print(f"* Stopping due to loss {loss_value} dropping below lower limit {loss_lim}")

        if n_steps % 100 == 0:
            print(f"Step {n_steps}, Loss: {loss_value}")

    print(f"Counted in n_steps = {n_steps} with loss = {loss_value}")

    # Model the *final* counts: the frame built inside the loop predates the last
    # optimizer step, and does not exist at all if the loop body never ran.
    with torch.no_grad():
        modeled_frame = construct_modeled_frame_pytorch(frame_ct_tensor, splash)

    loss_steps = np.array(loss_steps)
    return frame_ct_tensor.detach().cpu().numpy(), modeled_frame.cpu().numpy(), loss_steps
def frame_to_indices_weights(counted_frames):
    """
    Turn a batch of counted frames into sparse (linear index, weight) lists.

    Parameters
    ----------
    counted_frames : ndarray
        Batch of 2D counted frames.

    Returns
    -------
    tuple
        - list: Linear indices of non-zero pixels for each frame.
        - list: Weights (counts) at those indices.
    """
    # Shape of one frame, needed to flatten multi-dimensional coordinates
    single_frame_shape = counted_frames.shape[1:]

    # Coordinates of the non-zero pixels, one tuple of index arrays per frame
    hit_coords = [np.nonzero(single) for single in counted_frames]

    # Flattened positions of those pixels within a single frame
    all_linear_indices = [np.ravel_multi_index(coords, single_frame_shape) for coords in hit_coords]

    # Pixel values (counts) at the same positions, in the same order
    all_weights = [single[coords] for single, coords in zip(counted_frames, hit_coords)]

    return all_linear_indices, all_weights
def gaussian_splash_pytorch(A, sigma, size=3):
    """
    Create a Gaussian "splash" kernel for convolution in PyTorch.

    Parameters
    ----------
    A : float
        Amplitude of the Gaussian.
    sigma : float
        Standard deviation of the Gaussian.
    size : int, optional
        Size of the kernel (default: 3). Pixel offsets run from ``-(size // 2)`` upward,
        so only odd sizes give a kernel that is symmetric about its central pixel.

    Returns
    -------
    torch.Tensor
        4D tensor of shape (1, 1, size, size) representing the Gaussian kernel,
        allocated on the module-level ``device``.
    """
    # Pixel offsets from the kernel center, e.g. [-1, 0, 1] for size=3
    offsets = torch.arange(0, size, device=device) - (size // 2)

    # indexing='ij' is what meshgrid has always done by default; stating it explicitly
    # avoids the deprecation warning and keeps the layout fixed if the default changes.
    X, Y = torch.meshgrid(offsets, offsets, indexing='ij')

    # Compute the Gaussian kernel
    gaussian = A * torch.exp(-(X**2 + Y**2) / (2 * sigma**2))

    # Reshape to 4D tensor: (out_channels, in_channels, height, width)
    gaussian_kernel = gaussian.unsqueeze(0).unsqueeze(0)

    return gaussian_kernel
def update_counted_data_hdf5(file_path, nframes, batch_start_idx, frames_indices, frames_weights, scan_shape, frame_shape, group_name='electron_events'):
    """
    Write one batch of counted frames into an HDF5 file, creating the layout on first use.

    Parameters
    ----------
    file_path : str
        Path to the HDF5 file.
    nframes : int
        Total number of frames.
    batch_start_idx : int
        Starting index for the current batch.
    frames_indices : list
        List of arrays with pixel indices for each frame.
    frames_weights : list
        List of arrays with weights for each frame.
    scan_shape : tuple
        Shape of the scan grid (Ny, Nx).
    frame_shape : tuple
        Shape of each frame (Ny, Nx).
    group_name : str, optional
        HDF5 group name (default: 'electron_events').
    """
    # Append mode: the file is created if missing and existing content is kept
    with h5py.File(file_path, 'a') as h5_out:

        # Fetch the target group, creating it on the first call
        events = h5_out[group_name] if group_name in h5_out else h5_out.create_group(group_name)

        # Variable-length dataset holding the pixel indices of every frame
        if 'frames' in events:
            indices_ds = events['frames']
        else:
            indices_dtype = h5py.special_dtype(vlen=np.dtype('uint32'))
            indices_ds = events.create_dataset("frames", (nframes,), dtype=indices_dtype)

            # The frame dimensions travel with the 'frames' dataset as attributes
            indices_ds.attrs['Nx'] = frame_shape[1]
            indices_ds.attrs['Ny'] = frame_shape[0]

        # Variable-length dataset holding the counts at those pixels
        if 'weights' in events:
            weights_ds = events['weights']
        else:
            weights_dtype = h5py.special_dtype(vlen=np.dtype('uint32'))
            weights_ds = events.create_dataset("weights", (nframes,), dtype=weights_dtype)

        # scan_positions is simply 0..nframes-1; the scan grid dimensions are attached
        # to it as attributes
        if 'scan_positions' not in events:
            positions_ds = events.create_dataset('scan_positions', data=np.arange(nframes))
            positions_ds.attrs['Nx'] = scan_shape[1]
            positions_ds.attrs['Ny'] = scan_shape[0]

        # Store each frame of the batch at its global position in the datasets
        for slot, (indices, weights) in enumerate(zip(frames_indices, frames_weights), start=batch_start_idx):
            indices_ds[slot] = indices
            weights_ds[slot] = weights
# -----------------------------------------------------------------------------------
# Main counting function
# -----------------------------------------------------------------------------------
def count_frames(frames_file, counted_file, frames_per_batch, th_single_elec, baseline, gauss_A, gauss_sigma,
                 n_steps_max=5000, loss_per_frame_stop=1, min_loss_patience=10, min_loss_improvement=0.01,
                 batch_start=0, batch_end=-1, nframes_prior=0, record_loss_curves=True):
    """
    Count electrons in frames and save results to an HDF5 file.

    Frames are read from ``frames_file`` in batches, counted with
    ``count_frame_pytorch``, optionally corrected with a Poisson prior, converted to
    sparse index/weight lists and written to ``counted_file``.

    Parameters
    ----------
    frames_file : str
        Path to the input HDF5 file with raw frames (dataset 'frames', which must
        carry a 'scan_dimensions' attribute).
    counted_file : str
        Path to the output HDF5 file for counted data.
    frames_per_batch : int
        Number of frames to process per batch. The final batch may hold fewer frames.
    th_single_elec : float
        Threshold for single electron detection: entries of the initial count guess
        below this value are set to 0 before optimization.
    baseline : float
        Baseline value to subtract from frames.
    gauss_A : float
        Gaussian amplitude for electron profile.
    gauss_sigma : float
        Gaussian standard deviation for electron profile.
    n_steps_max : int, optional
        Maximum optimization steps (default: 5000).
    loss_per_frame_stop : float, optional
        Loss per frame threshold to stop (default: 1).
    min_loss_patience : int, optional
        Patience for loss improvement (default: 10).
    min_loss_improvement : float, optional
        Minimum loss improvement (default: 0.01).
    batch_start : int, optional
        Starting batch index (default: 0).
    batch_end : int, optional
        Ending batch index, exclusive (default: -1, meaning all batches).
    nframes_prior : int, optional
        Number of frames for prior computation (default: 0, meaning no prior is applied).
    record_loss_curves : bool, optional
        Whether to record loss curves (default: True).

    Returns
    -------
    tuple
        - list: Loss curves for each batch.
        - ndarray: Last batch's counted frames (None if no batch was processed).
        - ndarray: Last batch's modeled frames (counted frames convoluted with Gaussian
          profile; None if no batch was processed).

    Raises
    ------
    ValueError
        If ``frames_per_batch`` is not positive.
    """
    if frames_per_batch <= 0:
        raise ValueError(f"frames_per_batch must be positive, got {frames_per_batch}")

    # Compute the prior if nframes_prior > 0.
    if nframes_prior > 0:
        print("Computing the prior...")

        # Compute the prior.
        prior_frame = compute_prior(frames_file, nframes_prior, baseline, gauss_A)

        # Compute the conditional probabilities of having >= 2 counts, given >= 1 count.
        # This single-frame array is broadcast against each batch below.
        conditional_prob = compute_conditional_probabilities(prior_frame)

    # Get the total number of frames and scan shape.
    with h5py.File(frames_file, 'r') as f0:
        data = f0['frames']
        nframes = data.shape[0]
        frame_shape = data.shape[1:]
        scan_shape = data.attrs['scan_dimensions']
    print(f"Counting all {nframes} frames for scan of shape {scan_shape}")

    # Record all loss curves.
    loss_curves = []

    # Results of the most recent batch; they stay None if the batch range is empty.
    frame_ct_reco = None
    modeled_frame_reco = None

    # Ceiling division, so a trailing partial batch is processed rather than dropped.
    batches = (nframes + frames_per_batch - 1) // frames_per_batch
    print(f"Analyzing in {batches} batches...")
    if batch_end < 0:
        batch_end = batches
    for batch in range(batches)[batch_start:batch_end]:

        print(f"\n\n ** BATCH {batch} **")
        first_frame = batch * frames_per_batch

        # Get the frames for this batch
        print("-- Processing frames...")
        with h5py.File(frames_file, 'r') as f0:
            data = f0['frames']

            # Get the frames (the slice is shorter than frames_per_batch for a partial final batch)
            frame_bls = np.array(data[first_frame:first_frame + frames_per_batch, :, :], dtype=np.float32) - baseline

        # Compute an initial counted frame: amplitude in units of single electrons,
        # rounded to the nearest integer, with sub-threshold entries zeroed.
        frame_ct = np.rint(frame_bls / gauss_A, out=np.zeros(frame_bls.shape, dtype=np.int16), casting='unsafe')
        frame_ct[frame_ct < th_single_elec] = 0

        # -------------------------------------------------------------------------------
        # Count the frames.
        print("-- Counting frames...")
        frame_ct_reco, modeled_frame_reco, loss_steps = count_frame_pytorch(
            frame_bls, frame_ct, gauss_A, gauss_sigma,
            n_steps_max=n_steps_max,
            loss_lim=len(frame_bls) * loss_per_frame_stop,
            min_loss_patience=min_loss_patience,
            min_loss_improvement=min_loss_improvement)

        # Counts must be non-negative integers
        frame_ct_reco[frame_ct_reco < 0] = 0
        frame_ct_reco = np.rint(frame_ct_reco)
        if record_loss_curves:
            loss_curves.append(loss_steps)
        # -------------------------------------------------------------------------------

        # -------------------------------------------------------------------------------
        # Apply the prior for each frame.
        if nframes_prior > 0:

            # Generate one uniform random number in [0, 1) per pixel of every frame in the batch
            random_matrix = np.random.rand(*frame_ct_reco.shape)

            # Compare the random matrix to the probability matrix:
            # - if the random number is greater than the conditional probability,
            #   the multi-count is said to have been due to the Landau
            #   tail, and therefore the count is set to 1
            single_electron_pixels = random_matrix > conditional_prob[np.newaxis, :, :]

            # Set all counts > 1 that did not pass the probability game to 1 count.
            forced_single_electrons = (frame_ct_reco > 1) & single_electron_pixels
            n_single_elec = np.sum(forced_single_electrons)
            # Use the actual batch size so the percentage is right for a partial batch
            n_total = frame_ct_reco.size
            print(f"{n_single_elec} of {n_total} ({n_single_elec/n_total*100:.4f}%) forced to single-electron counts")
            frame_ct_reco[forced_single_electrons] = 1
        # -------------------------------------------------------------------------------

        # Save the counted frames in the array
        print("-- Saving frames...")
        frames_indices, frames_weights = frame_to_indices_weights(frame_ct_reco)
        print(f"Frame indices len = {len(frames_indices)} and weights = {len(frames_weights)}")
        update_counted_data_hdf5(counted_file, nframes, first_frame, frames_indices, frames_weights, scan_shape, frame_shape)

    # Return the loss curve for the counting
    return loss_curves, frame_ct_reco, modeled_frame_reco