# Python code for analysis of phosphorylated rpS6 235/236 immunolabeling
# in the olfactory system of larval Xenopus laevis

def process_nd2(file):
    '''
    Open a file containing an image stack with immunolabelled rpS6 signal.
    Multiple channel recordings are supported, but the filename must
    indicate which channel is rpS6. Processing includes:
    - Autofluorescence correction (if a blue channel was recorded it is
      subtracted from the signal channel; otherwise all oversaturated
      pixels are set to 0)
    - Crop image in y dimension (to exclude cellular signals from the
      olfactory bulb); the range is looked up in 'cropping_range.csv'
    - Remove background (high sigma gaussian filtering of each plane)
    - Median filtering 3 px radius (only in x,y dimension)
    - Quantify cumulative rpS6 signal intensity
    - Create and save maximum intensity projection overviews for visual
      control as 'overviews/<file stem>.jpg' (the folder is created if
      it does not exist)

    Source image files are raw image stacks (Nikon nd2-files) acquired
    with an A1R-MP multiphoton microscope.

    Parameters:
    file : str
        Path to the nd2 file.

    Returns:
    The sum of all pixel intensities of the cropped, background-subtracted
    and median-filtered stack.

    Raises:
    ValueError
        If the file name does not indicate which channel holds rpS6.
    '''
    import os
    import numpy as np
    import nd2
    from skimage.morphology import footprints
    from scipy.ndimage import median_filter

    saturation = 4095  # saturation / display maximum (2**12 - 1, presumably 12-bit data — TODO confirm)
    bg_sigma = 35      # width (px) of the Gaussian background estimate

    # import image stack (z,channel,y,x); signed dtype so that the blue-channel
    # subtraction may go below zero before being clipped
    img_stack = np.int16(nd2.imread(file))

    # get from file name which channel contains the rpS6 signal
    ps6_ch = get_channel_from_fname(file)
    if ps6_ch == 'none':
        raise ValueError(f"Cannot tell the rpS6 channel from the file name: {file!r}")

    # with three or more channels, channel 0 is used as the blue autofluorescence channel
    blue_channel_recorded = img_stack.shape[1] >= 3
    if blue_channel_recorded:
        # subtract blue channel from signal channel; negative values make no sense
        img_stack = img_stack[:, ps6_ch, :, :] - img_stack[:, 0, :, :]
        img_stack = np.where(img_stack < 0, 0, img_stack)
    else:
        # no channel 0 recorded, thus the channel number shifts down by one;
        # set supersaturated areas to 0
        signal = img_stack[:, ps6_ch - 1, :, :]
        img_stack = np.where(signal >= saturation, 0, signal)

    # crop image stack on the y-axis (range stored per file in a csv)
    ycutoff = get_ylimits_from_csv(file, 'cropping_range.csv')
    img_stack = img_stack[:, ycutoff[0]:ycutoff[1], :]

    # remove background for each plane (wide gaussian background estimation)
    img_minus_bg, background_stack = subtract_gauss_background(img_stack, sigma=bg_sigma)

    # maximum intensity projections along z for the comparison plot
    original_max = img_stack.max(axis=0)
    background_max = background_stack.max(axis=0)
    img_minus_bg_max = img_minus_bg.max(axis=0)

    # median filtering after background subtraction: disk of radius 3 (7x7);
    # the leading axis of length 1 keeps the filter inside each z-plane
    footprint = footprints.disk(3)[np.newaxis, :, :]
    img_minus_bg = median_filter(img_minus_bg, footprint=footprint)
    img_minus_bg_filtered_max = img_minus_bg.max(axis=0)

    # cumulative rpS6 intensity of the image stack cleared of background signal
    int_sum = img_minus_bg.sum()

    # plot the steps of image processing and save them (for visual control)
    namestem = os.path.splitext(os.path.basename(file))[0]
    _save_overview(
        namestem,
        [
            (original_max, "Original data"),
            (background_max, f"Gaussian filter (σ={bg_sigma})\nas background estimation"),
            (img_minus_bg_max, "Background subtracted"),
            (img_minus_bg_filtered_max, "Background subtracted + Median filter (r=3px)"),
        ],
        vmax=saturation,
    )
    return int_sum


def _save_overview(namestem, panels, vmax=4095, outdir='overviews'):
    '''
    Save 2d images side by side as a JPEG for visual control.

    panels: list of (image, title) tuples, drawn left to right with the
    shared display range [0, vmax] and axes hidden. The figure is written to
    '<outdir>/<namestem>.jpg' (outdir is created if missing) and is always
    closed, even if saving fails.
    '''
    import os
    import matplotlib.pyplot as plt

    os.makedirs(outdir, exist_ok=True)
    # squeeze=False keeps axs a 2d array even for a single panel
    fig, axs = plt.subplots(1, len(panels), figsize=(10, 3), dpi=300, squeeze=False)
    try:
        for ax, (image, title) in zip(axs.ravel(), panels):
            ax.imshow(image, vmin=0, vmax=vmax)
            ax.set_title(title, wrap=True)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, namestem + '.jpg'), format='jpg')
    finally:
        plt.close(fig)


def gaussian_background(stack, sigma=35):
    '''
    Estimate a smooth background for every plane of an image stack.

    Each z-plane is blurred on its own with skimage's Gaussian filter of
    width ``sigma`` (pixels). ``preserve_range=True`` keeps the original
    intensity scale instead of rescaling.

    stack: 3d array with dimension order (z,y,x)
    Returns the blurred planes stacked along the first axis.
    '''
    import numpy as np
    from skimage.filters import gaussian

    blurred_planes = []
    for z_plane in stack:
        blurred_planes.append(gaussian(z_plane, sigma=sigma, preserve_range=True))
    return np.array(blurred_planes)


def subtract_gauss_background(stack, sigma=35):
    '''
    Remove a wide Gaussian background estimate from an image stack.

    The background is estimated plane by plane with ``gaussian_background``
    and subtracted from ``stack``. Any resulting negative values are
    clamped to zero.

    Returns a tuple (background-subtracted stack, background estimate).
    '''
    import numpy as np

    background = gaussian_background(stack, sigma=sigma)
    difference = stack - background
    clamped = np.where(difference < 0, 0, difference)
    return clamped, background


def get_channel_from_fname(fname):
    '''
    Infer which channel carries the rpS6 235/236 signal from a file name.

    File names contain a marker for the channel used:
    'pS6.235.36_488' (green/488) maps to channel 1 and
    'pS6.235.36_594' (red/594) maps to channel 2. The 488 marker is
    checked first. If neither marker is found, the string 'none' is
    returned.
    '''
    channel_markers = (
        ('pS6.235.36_488', 1),
        ('pS6.235.36_594', 2),
    )
    return next(
        (channel for marker, channel in channel_markers if marker in fname),
        'none',
    )


def load_ycutoffs(csvfile):
    '''
    Read y-axis cropping ranges for imaging files from a csv file.

    The csv file needs a header row with the columns filename,
    ycutoff_begin and ycutoff_end. The function returns a dictionary that
    maps each filename to an int tuple (ycutoff_begin, ycutoff_end). If a
    filename is listed more than once, the last row wins.
    '''
    import csv

    # the csv module expects file objects to be opened with newline=''
    with open(csvfile, newline='') as f:
        return {
            row['filename']: (int(row['ycutoff_begin']), int(row['ycutoff_end']))
            for row in csv.DictReader(f)
        }


def get_ylimits_from_csv(filename, csvfile):
    '''
    Get the individual y-axis boundaries of ON/OB for one image file.

    Only the base name of ``filename`` is used as the key into the table
    that ``load_ycutoffs`` reads from ``csvfile``, so ``filename`` may be a
    full path. Returns that entry, a (ycutoff_begin, ycutoff_end) tuple.

    Raises KeyError if the file is not listed in ``csvfile``.
    '''
    import os

    ycutoffs = load_ycutoffs(csvfile)
    # os.path.basename also handles '\\'-separated paths on Windows
    key = os.path.basename(filename)
    try:
        return ycutoffs[key]
    except KeyError:
        raise KeyError(f"{key!r} has no cropping range in {csvfile!r}") from None


def annotate_significance(ax, pairs, pvalues, significance_thresholds=(0.05, 0.01, 0.001)):
    """
    Annotate a (e.g. Seaborn) boxplot with significance brackets and asterisks.

    For every pair of box positions, a bracket is drawn above the current
    upper y-limit of ``ax``. Successive brackets are stacked upwards. Each
    bracket is labelled with one asterisk per threshold in
    ``significance_thresholds`` that the p-value falls below. With the
    defaults this gives p < 0.05 -> "*", p < 0.01 -> "**" and
    p < 0.001 -> "***". P-values below no threshold are labelled "n.s.".

    Parameters:
    ax : matplotlib.axes.Axes
        The axis on which the boxplot is drawn.
    pairs : list of tuple
        List of pairs of box indices (x positions) to be compared.
    pvalues : list of float
        List of p-values corresponding to each pair.
    significance_thresholds : tuple of float, optional
        Significance thresholds; every threshold a p-value is below adds
        one asterisk.

    Raises:
    ValueError
        If ``pairs`` and ``pvalues`` differ in length.

    Example:
    pairs = [(0, 1), (2, 3)]  # Compare box 0 with 1 and box 2 with 3
    pvalues = [0.03, 0.002]
    """
    # explicit check rather than assert, which is stripped under python -O
    if len(pairs) != len(pvalues):
        raise ValueError("Pairs and p-values lists must have the same length.")

    # bracket geometry scales with the current upper y-limit of the axis
    y_max = ax.get_ylim()[1]
    y_offset = y_max * 0.1        # vertical spacing between stacked brackets
    line_height = y_max * 0.025   # height of the bracket's vertical arms

    def get_significance(pvalue):
        """Return asterisks based on the p-value, or "n.s." if none apply."""
        n_stars = sum(pvalue < threshold for threshold in significance_thresholds)
        return "*" * n_stars if n_stars else "n.s."

    for i, ((box1, box2), pval) in enumerate(zip(pairs, pvalues)):
        y = y_max + (i + 0.25) * y_offset   # lower end of the bracket arms
        line_y = y + line_height            # horizontal bar of the bracket

        ax.plot([box1, box1, box2, box2], [y, line_y, line_y, y], lw=1, color="black")

        # text is anchored (va='bottom') half an arm-height below the bar
        ax.text((box1 + box2) * 0.5, line_y - line_height * 0.5, get_significance(pval),
                ha='center', va='bottom', color="black", fontsize=10)


def remove_mirrored_duplicates(input_list):
    '''
    Remove mirrored duplicates of pairs in a list.

    A pair is skipped when its mirror [pair[1], pair[0]] has already been
    kept, so of [1,3] and [3,1] only the one that comes first survives.
    Kept pairs are copied into new lists, and input order is preserved.
    Exact repeats such as [1,3], [1,3] are both kept, unless the pair is
    symmetric (e.g. [2,2]).
    '''
    kept_pairs = []
    for candidate in input_list:
        mirrored = [candidate[1], candidate[0]]
        if mirrored not in kept_pairs:
            kept_pairs.append(list(candidate))
    return kept_pairs
