From 47199f11797856592c0955c4f11dbfa33264d1b6 Mon Sep 17 00:00:00 2001 From: jfnavarro Date: Tue, 1 Nov 2016 14:23:41 +0100 Subject: [PATCH] Fixed a small bug with 3 dimensional unsupervised learning --- scripts/differential_analysis.py | 3 +- scripts/unsupervised.py | 4 +- setup.py | 2 +- stanalysis/visualization.py | 78 ++++++++++++++++++++++++++++---- 4 files changed, 74 insertions(+), 13 deletions(-) diff --git a/scripts/differential_analysis.py b/scripts/differential_analysis.py index 8aee59f..c03a7e6 100644 --- a/scripts/differential_analysis.py +++ b/scripts/differential_analysis.py @@ -123,8 +123,9 @@ def main(input_data, data_classes, conditions_tuples, outdir): dea_results.to_csv(os.path.join(outdir, "dea_results_{}.tsv".format(cond)), sep="\t") # Volcano plot print "Generating plots..." + # TODO add colors according to differently expressed or not scatter_plot(dea_results["log2FoldChange"], -np.log10(dea_results["pvalue"]), - xlabel="Log2FoldChange", ylabel="-log10(pvalue)", + xlabel="Log2FoldChange", ylabel="-log10(pvalue)", colors=None, title="Volcano plot", output=os.path.join(outdir, "volcano_{}.png".format(cond))) if __name__ == '__main__': diff --git a/scripts/unsupervised.py b/scripts/unsupervised.py index dd82e90..d8af0d7 100644 --- a/scripts/unsupervised.py +++ b/scripts/unsupervised.py @@ -164,9 +164,9 @@ def main(counts_table_files, y_min = min(reduced_data[:,1]) x_p = reduced_data[:,0] y_p = reduced_data[:,1] - z_p = y_p + z_p = None if num_dimensions == 3: - z_p = reduced_data[:,3] + z_p = reduced_data[:,2] z_max = max(reduced_data[:,2]) z_min = min(reduced_data[:,2]) for x,y,z in zip(x_p,y_p,z_p): diff --git a/setup.py b/setup.py index d2d1f99..db054d9 100755 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( name = 'stanalysis', - version = "0.2.0", + version = "0.2.1", description = __doc__.split("\n", 1)[0], long_description = long_description, keywords = 'rna-seq analysis spatial transcriptomics toolkit', diff --git a/stanalysis/visualization.py b/stanalysis/visualization.py index 4a3c19b..b425476 100644 --- a/stanalysis/visualization.py +++ b/stanalysis/visualization.py @@ -48,8 +48,8 @@ def histogram(x_points, output, title="Histogram", xlabel="X", fig.set_size_inches(16, 16) fig.savefig(output, dpi=300) -def scatter_plot3d(x_points, y_points, z_points, colors, - output, cmap=None, title='Scatter', xlabel='X', +def scatter_plot3d(x_points, y_points, z_points, output, + colors=None, cmap=None, title='Scatter', xlabel='X', ylabel='Y', zlabel="Z", alpha=1.0, size=50): """ This function makes a scatter 3d plot of a set of points (x,y,z). @@ -58,13 +58,14 @@ def scatter_plot3d(x_points, y_points, z_points, colors, :param x_points: a list of x coordinates :param y_points: a list of y coordinates :param z_points: a list of z coordinates (optional) - :param colors: a color label for each point + :param output: the name/path of the output file + :param colors: a color label for each point (can be None) + :param alignment: an alignment 3x3 matrix (pass identity to not align) :param cmap: Matplotlib color mapping object (optional) :param title: the title for the plot :param xlabel: the name of the X label :param ylabel: the name of the Y label - :param zlabel: the name of the Z label - :param output: the name/path of the output file + :param image: the path to the image file :param alpha: the alpha transparency level for the dots :param size: the size of the dots :raises: RuntimeError @@ -77,7 +78,7 @@ def scatter_plot3d(x_points, y_points, z_points, colors, color_list = set(colors) color_values = [color_map[i] for i in color_list] cmap = ListedColormap(color_values) - else: + elif colors is None: colors = "blue" a.scatter(x_points, y_points, @@ -98,7 +99,7 @@ def scatter_plot3d(x_points, y_points, z_points, colors, fig.set_size_inches(16, 16) fig.savefig(output, dpi=300) -def scatter_plot(x_points, y_points, colors, output, +def scatter_plot(x_points, y_points, output, colors=None, alignment=None, cmap=None, title='Scatter', xlabel='X', ylabel='Y', image=None, alpha=1.0, size=50): """ @@ -110,14 +111,73 @@ def scatter_plot(x_points, y_points, colors, output, The plot will be written to a file. :param x_points: a list of x coordinates :param y_points: a list of y coordinates - :param colors: a color label for each point + :param output: the name/path of the output file + :param colors: a color label for each point (can be None) + :param alignment: an alignment 3x3 matrix (pass identity to not align) :param cmap: Matplotlib color mapping object (optional) :param title: the title for the plot :param xlabel: the name of the X label :param ylabel: the name of the Y label :param image: the path to the image file + :param alpha: the alpha transparency level for the dots + :param size: the size of the dots + :raises: RuntimeError + """ + # Plot spots with the color class in the tissue image + fig = plt.figure(figsize=(16,16)) + a = fig.add_subplot(111, aspect='equal') + base_trans = a.transData + extent_size = (1,33,35,1) + # If alignment is None we re-size the image to chip size (1,1,33,35) + if alignment is not None: + base_trans = transforms.Affine2D(matrix = alignment) + base_trans + extent_size = None + color_values = None + if cmap is None and colors is not None: + color_list = set(colors) + color_values = [color_map[i] for i in color_list] + cmap = ListedColormap(color_values) + elif colors is None: + colors = "blue" + a.scatter(x_points, + y_points, + c=colors, + cmap=cmap, + edgecolor="none", + s=size, + transform=base_trans, + alpha=alpha) + if image is not None and os.path.isfile(image): + img = plt.imread(image) + # TODO imgshow() will not work if I pass extent_size as variable + a.imshow(img, extent=(1,33,35,1)) + a.set_xlabel(xlabel) + a.set_ylabel(ylabel) + if color_values is not None: + a.legend([plt.Line2D((0,1),(0,0), color=x) for x in color_values], + color_list, loc="upper right", markerscale=1.0, + ncol=1, scatterpoints=1, fontsize=10) + a.set_title(title, size=20) + fig.set_size_inches(16, 16) + fig.savefig(output, dpi=300) + +def volcano_plot(x_points, y_points, output, + title='Volcano', xlabel='X', ylabel='Y'): + """ + This function makes a Volcano plot with the given + log2foldChange values (X) and -log10(pvalues) (Y). + It will color the most differently expressed + values (using the p-value threshold) as + :param x_points: a list of x coordinates + :param y_points: a list of y coordinates :param output: the name/path of the output file + :param colors: a color label for each point (can be None) :param alignment: an alignment 3x3 matrix (pass identity to not align) + :param cmap: Matplotlib color mapping object (optional) + :param title: the title for the plot + :param xlabel: the name of the X label + :param ylabel: the name of the Y label + :param image: the path to the image file :param alpha: the alpha transparency level for the dots :param size: the size of the dots :raises: RuntimeError @@ -136,7 +196,7 @@ def scatter_plot(x_points, y_points, colors, output, color_list = set(colors) color_values = [color_map[i] for i in color_list] cmap = ListedColormap(color_values) - else: + elif colors is None: colors = "blue" a.scatter(x_points, y_points,