Advertisement
OreganoHauch

violin plot (02-01-2021)

Jan 2nd, 2021
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.29 KB | None | 0 0
  1. def violin(kmeans_input=7,currents=False, title=True, data_processing=1, dpi=None, savepng=False):
  2. start = time.time()
  3.  
  4. figsize = (30*0.75,15*0.75)
  5. fontsize = figsize[0]
  6. fig, ax = plt.subplots(figsize=figsize)
  7.  
  8. scan_order = [185,183,181,182,186]
  9.  
  10. xlabels = ["-200","-100","0","+100","+200"]
  11.  
  12. if currents:
  13. if data_processing == 1:
  14. current_array1 = np.array(load_cache_currents()["pixelplotting_array"][scan_order.index(185)]).ravel()[np.newaxis,:].T
  15. current_array2 = np.array(load_cache_currents()["pixelplotting_array"][scan_order.index(183)]).ravel()[np.newaxis,:].T
  16. current_array3 = np.array(load_cache_currents()["pixelplotting_array"][scan_order.index(181)]).ravel()[np.newaxis,:].T
  17. current_array4 = np.array(load_cache_currents()["pixelplotting_array"][scan_order.index(182)]).ravel()[np.newaxis,:].T
  18. current_array5 = np.array(load_cache_currents()["pixelplotting_array"][scan_order.index(186)]).ravel()[np.newaxis,:].T
  19. if data_processing == 3:
  20. current_array1 = np.array(load_cache_currents()["pixelplotting_array_direct"][scan_order.index(185)]).ravel()[np.newaxis,:].T
  21. current_array1 = current_array1 - np.nanmin(current_array1)
  22. current_array2 = np.array(load_cache_currents()["pixelplotting_array_direct"][scan_order.index(183)]).ravel()[np.newaxis,:].T
  23. current_array2 = current_array2 - np.nanmin(current_array2)
  24. current_array3 = np.array(load_cache_currents()["pixelplotting_array_direct"][scan_order.index(181)]).ravel()[np.newaxis,:].T
  25. current_array3 = current_array3 - np.nanmin(current_array3)
  26. current_array4 = np.array(load_cache_currents()["pixelplotting_array_direct"][scan_order.index(182)]).ravel()[np.newaxis,:].T
  27. current_array4 = current_array4 - np.nanmin(current_array4)
  28. current_array5 = np.array(load_cache_currents()["pixelplotting_array_direct"][scan_order.index(186)]).ravel()[np.newaxis,:].T
  29. current_array5 = current_array5 - np.nanmin(current_array5
  30. numpy_data = np.concatenate((current_array1,current_array2,current_array3,current_array4,current_array5),axis=1)
  31. df = pd.DataFrame(data=numpy_data, columns=xlabels)
  32. ax = sn.violinplot(data=df, scale="count",inner="box")
  33.  
  34. else:
  35. numpy_dataDic = {}
  36. for scan_index in scan_order:
  37. current_array_scan = np.array(load_cache_currents()["pixelplotting_array"][scan_order.index(scan_index)])
  38. data = current_array_scan.ravel()[np.newaxis,:].T
  39. scan = np.repeat(scan_index,len(data))[np.newaxis,:].T
  40. group = np.array(load_cache_kmeans(kmeans_input = kmeans_input)["2"]["cluster_numbers"]).ravel()[np.newaxis,:].T
  41. numpy_dataDic[scan_index] = np.concatenate((data,scan,group),axis=1)
  42. #plt.axvline(x=scan_index-181.5, color="grey", lw=1)
  43. numpy_data = np.concatenate((numpy_dataDic[185],numpy_dataDic[183],numpy_dataDic[181],numpy_dataDic[182],numpy_dataDic[186]))
  44. df = pd.DataFrame(data=numpy_data, columns=["XBIC","Bias voltage (mV)","Group"])
  45. df = df.replace(1.0,"Group 1")
  46. df = df.replace(2.0,"Group 2")
  47. ax = sn.violinplot(x = "Bias voltage (mV)", y = "XBIC", hue = "Group", data = df, order=[185,183,181,182,186], palette = "viridis", split=True, linewidth = 0, scale = "count", inner = "box", ax = ax)
  48.  
  49. ax.set_xlabel("Bias voltage (mV)", labelpad = 10, fontsize=9/5*fontsize, fontweight="bold")
  50. ax.set_ylabel("XBIC", labelpad = 10, fontsize=9/5*fontsize, fontweight="bold")
  51. ax.tick_params(labelsize=fontsize, length=4/5*fontsize)
  52.  
  53. ax.xaxis.set_ticks([0,1,2,3,4])
  54. ax.xaxis.set_ticklabels(xlabels, fontsize=fontsize)
  55.  
  56. # add minor ticks for XBIC
  57. ax.yaxis.set_minor_locator(MultipleLocator(10))
  58. ax.tick_params(which='minor', length=2/5*fontsize)
  59. plt.grid(axis="y", which="minor",lw=0.25)
  60.  
  61. plt.grid(axis="y")
  62. if title:
  63. plt.title("Current distribution", pad = 15, fontsize=round(9/5*fontsize), fontweight="bold")
  64.  
  65. # add vertical lines
  66. for x in range(5):
  67. plt.axvline(x=x+0.5, color="grey", lw=1)
  68.  
  69. # meta = {
  70. # "colour": {"a": "red", "b": "blue"},
  71. # "label": {"a": "this a", "b": "this b"},
  72. # }
  73.  
  74. # handles = [
  75. # Patch(
  76. # facecolor=meta["colour"][x],
  77. # label=meta["label"][x],
  78. # )
  79. # for x in meta["colour"].keys()
  80. # ]
  81.  
  82. # fig.legend(
  83. # handles=handles,
  84. # ncol=2,
  85. # loc="upper right",
  86. # bbox_to_anchor=(0.5, 0.5),
  87. # fontsize=10,
  88. # handlelength=1,
  89. # handleheight=1,
  90. # frameon=False,
  91. # )
  92.  
  93. ax.legend(prop={'size': fontsize})
  94. #ax.get_legend().remove()
  95.  
  96. plt.show()
  97.  
  98. end = time.time()
  99. print(f"Plotting of the violin plots took {str(round(end-start,2))} seconds.")
  100.  
  101. if savepng:
  102. now = datetime.now()
  103. dt_string = now.strftime("%d-%m-%Y_%H_%M_%S")
  104. if dpi is None:
  105. fig.savefig("savefig/violin_" + dt_string + ".png", dpi=fig.dpi, bbox_inches="tight")
  106. else:
  107. fig.savefig("savefig/violin_" + dt_string + ".png", dpi=dpi, bbox_inches="tight")
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement