nafuma/nafuma/dft/phonons.py
2022-10-09 18:38:46 +02:00

1700 lines
No EOL
56 KiB
Python

import re
from this import d
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import subprocess
import os
import shutil
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,AutoMinorLocator)
from mpl_toolkits.axes_grid.inset_locator import (inset_axes, InsetPosition,
mark_inset)
import importlib
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
from cycler import cycler
import itertools
def get_atoms(path='.'):
poscar = os.path.join(path, 'POSCAR')
with open(poscar, 'r') as poscar:
lines = poscar.readlines()
atoms = lines[5].split()
atom_num = lines[6].split()
atom_num = [int(num) for num in atom_num]
return atoms, atom_num
def get_dimensions(path='.'):
poscar = os.path.join(path, 'POSCAR')
sposcar = os.path.join(path, 'SPOSCAR')
with open(poscar, 'r') as poscar:
lines_pos = poscar.readlines()
with open(sposcar, 'r') as sposcar:
lines_spos = sposcar.readlines()
a_p, b_p, c_p = lines_pos[2].split(), lines_pos[3].split(), lines_pos[4].split()
a_s, b_s, c_s = lines_spos[2].split(), lines_spos[3].split(), lines_spos[4].split()
lattice_params_poscar, lattice_params_sposcar = [a_p, b_p, c_p], [a_s, b_s, c_s]
poscar_new = []
sposcar_new = []
for lp_p, lp_s in zip(lattice_params_poscar, lattice_params_sposcar):
lp_p = np.sqrt(float(lp_p[0])**2 + float(lp_p[1])**2 + float(lp_p[2])**2)
lp_s = np.sqrt(float(lp_s[0])**2 + float(lp_s[1])**2 + float(lp_s[2])**2)
poscar_new.append(lp_p)
sposcar_new.append(lp_s)
dim = [int(lp_s/lp_p) for lp_s, lp_p in zip(sposcar_new, poscar_new)]
return dim
def read_band(band_dir):
''' Reads a band file as written by the function write_phonon_bands() into a pandas DataFrame and returns this. Contains two columns: k-points (the "distance" output in the band.yaml-file by phonopy) and frequencies.
Input:
band_dir: the path to the band-file.
Output:
band: pandas DataFrame containing frequencies of the band along the k-point path specified in the phonopy calculation'''
# Read the band into a pandas DataFrame
band = pd.read_csv(band_dir, delim_whitespace=True, header=None, names=['kpt', 'frequency'])
return band
def read_kpoints(kpoints_dir):
''' Reads a VASP KPOINTS-file in line mode. Returns two lists: special_points_coords, containing the coordinates of the special points in k-space and special_points_labels, the names of these special points.
Requires a KPOINTS-file that is in line mode with special points indicated with a "!".
Input:
kpoints_dir: the path to the KPOINTS-file
Output:
special_points_coords: List of 3D coordinates of the k-space special points
special_points_labels: List of names of the k-space special points'''
# Open the KPOINTS-file and read it line by line, appending each line with a "!" to the special_points list.
special_points = []
with open(kpoints_dir) as kpoints:
lines = kpoints.readlines()
for line in lines:
if '!' in line:
special_points.append(line)
# Go through the special points to separate them into the coordinate and the label for each special point into special_points_coords and special_points_labels respectively.
special_points_coords = []
special_points_labels = []
for special_point in special_points:
if len(special_point.split()) == 5:
special_points_coords.append(special_point.split()[0:3])
special_points_labels.append(special_point.split()[-1])
return special_points_coords, special_points_labels
def get_kpoints_ticks(band):
''' Finds the coordinates for the special points in the 1D-projection given by phonopy (the parameter 'distance' in the band.yaml file). This is to determine the placement of labels and vertical lines in the bandstructure plot.
Input:
band: the path to a band_XX.dat file. Should not matter which one is passed here.
Output:
kpts_ticks: A list of coordinates corresponding to the special points.'''
band = np.genfromtxt(band)
kpts_ticks = []
# Append the first point
kpts_ticks.append(0.)
# Go through all data points - where the x-value repeats, a k-point tick is appended to the list
for j in np.arange(np.shape(band)[0]-1):
if band[j,0]==band[j+1,0]:
kpts_ticks.append(band[j,0])
# Append the last point
kpts_ticks.append(max(band[:,0]))
return kpts_ticks
def get_kpoints_labels(special_points_labels):
''' Takes the raw special point labels from read_kpoints() and writes them in a way to be used in the bandstructure plots.
Where there is a discontinuity in the path, the label is separated with a |.
Input:
special_points_labels: A list of special points as directly read from the KPOINTS-file by read_kpoints()
Ouput:
labels: A list of labels suitable to pass as x-ticks during plotting of the bandstructure plots.'''
# Loop through the raw special points labels list following a set of rules, to extract the labels suitable for plotting
labels = []
for ind, label in enumerate(special_points_labels):
# Add the first label as this will be a separate special point
if ind == 0:
label = '${}$'.format(label) if (label[0] == '\\') else label
labels.append(label)
# Add the last label, as this will also be a separate special point (or will it? Must change this if that is not always the case)
elif ind == len(special_points_labels)-1:
label = '${}$'.format(label) if (label[0] == '\\') else label
labels.append(label)
# Skip every second entry, as they will repeat due to the way the KPOINTS-file is constructed
elif ind%2 != 0:
continue
# Add label if it's continuous (i.e. if the current and previous points are the same), add "previous|current" if discontinuous (i.e. if they are not the same)
else:
if label == special_points_labels[ind-1]:
label = '${}$'.format(label) if (label[0] == '\\') else label # If the special point has a greek letter, such as the gamma point, makes sure that the label is enclosed in $ to be rendered correctly.
labels.append(label)
else:
label = '${}$'.format(label) if (label[0] == '\\') else label # If the special point has a greek letter, such as the gamma point, makes sure that the label is enclosed in $ to be rendered correctly.
previous_label = special_points_labels[ind-1]
previous_label = '${}$'.format(previous_label) if previous_label[0] == '\\' else previous_label # If the special point has a greek letter, such as the gamma point, makes sure that the label is enclosed in $ to be rendered correctly.
labels.append("{}|{}".format(previous_label, label))
return labels
def read_phonon_dos(dos_path):
''' Reads the phonon density of states from a total_dos.dat file as written by phonopy. This file will be generated by the function calculate_phonon_dos() as well as this calls phonopy to calculate the density of states.
Input:
dos_path: the path to the total_dos.dat file. Must include the filename
Output:
df: pandas DataFrame containing the contents of the total_dos.dat file. Two columns, "Frequency" and "DOS". '''
df = pd.read_csv(dos_path, header=None, skiprows=1, delim_whitespace=True)
df.columns = ['Frequency', 'DOS']
return df
def read_phonon_pdos(path, normalise=False, poscar=None):
''' Reads the phonon density of states from a total_dos.dat file as written by phonopy. This file will be generated by the function calculate_phonon_dos() as well as this calls phonopy to calculate the density of states.
Input:
dos_path: the path to the total_dos.dat file. Must include the filename
Output:
df: pandas DataFrame containing the contents of the total_dos.dat file. Two columns, "Frequency" and "DOS". '''
df = pd.read_csv(path, index_col=0)
if normalise and poscar:
atoms, atom_num = get_atoms(poscar)
for atom, num in zip(atoms, atom_num):
df[atom] = df[atom] / num
return df
def write_phonopy_band_path(special_points_coords):
''' Writes the band path used by phonopy to calculate the bandstructure from the raw information as extracted by read_kpoints().
Input:
special_points_coords: list of coordinates for the special points as read by read_kpoints(), that reads a VASP KPOINTS.bands file.
Output:
phonopy_band_path: '''
coords = []
for ind, coord in enumerate(special_points_coords):
# Add the first label
if ind == 0:
coord = "{} {} {} ".format(coord[0], coord[1], coord[2])
coords.append(coord)
# Add the last label
elif ind == len(special_points_coords)-1:
coord = "{} {} {}".format(coord[0], coord[1], coord[2])
coords.append(coord)
# Skip every second entry
elif ind%2 != 0:
continue
# Add label if it's continuous, add "previous|current" if discontinuous
else:
if coord == special_points_coords[ind-1]:
coord = "{} {} {} ".format(coord[0], coord[1], coord[2])
coords.append(coord)
else:
first_coord = "{} {} {}".format(special_points_coords[ind-1][0], special_points_coords[ind-1][1], special_points_coords[ind-1][2])
second_coord = "{} {} {} ".format(coord[0], coord[1], coord[2])
coords.append("{}, {} ".format(first_coord, second_coord))
phonopy_band_path = ''
for coord in coords:
phonopy_band_path = phonopy_band_path + coord
return phonopy_band_path
def write_mesh_conf(atoms, dim, mesh, dos_range=None, pdos=False, atom_num=None, tmax=None):
atom_str = 'ATOM_NAME = '
for atom in atoms:
atom_str += atom + " "
dim_str = 'DIM = '
for d in dim:
dim_str += str(d) + " "
mesh_str = 'MP = '
for m in mesh:
mesh_str += str(m) + " "
dos_str = 'DOS_RANGE = '
for d in dos_range:
dos_str += str(d) + " "
if tmax:
tmax_str = f'TMAX = {tmax}'
if pdos:
pdos_str = 'PDOS ='
atoms_sum = 0
for ind, atom in enumerate(atoms):
for i in range(1,atom_num[ind]+1):
pdos_str += " {}".format(i+atoms_sum)
# Add comma after numbers unless it's the last entry
if ind != len(atom_num)-1:
pdos_str += ','
atoms_sum = atoms_sum + atom_num[ind]
with open('mesh.conf', 'w') as conf:
conf.write(atom_str + '\n' + dim_str + '\n' + mesh_str)
if tmax:
conf.write('\n' + tmax_str)
if dos_range:
conf.write('\n' + dos_str)
if pdos:
conf.write('\n' + pdos_str)
conf.write('\n' + "WRITE_MESH = .FALSE.")
def write_band_conf(atoms, dim, mesh, band, band_points=None):
atom_str = 'ATOM_NAME = '
for atom in atoms:
atom_str += atom + " "
dim_str = 'DIM = '
for d in dim:
dim_str += str(d) + " "
mesh_str = 'MP = '
for m in mesh:
mesh_str += str(m) + " "
band_str = 'BAND = ' + band
if band_points:
band_points_str = "BAND_POINTS = " + band_points
with open('band.conf', 'w') as conf:
if not band_points:
conf.write(atom_str + '\n' + dim_str + '\n' + mesh_str + '\n' + band_str)
else:
conf.write(atom_str + '\n' + dim_str + '\n' + mesh_str + '\n' + band_str + '\n' + band_points_str)
def calculate_phonon_dos(path, atoms, dim, mesh, dos_range=None):
cwd = os.getcwd()
os.chdir(path)
write_mesh_conf(atoms, dim, mesh, dos_range=dos_range)
subprocess.call('phonopy -ps mesh.conf >> phonopy_output.dat', shell=True)
# Make folder and move output in there
os.mkdir('total_dos')
shutil.move('total_dos.pdf', 'total_dos/total_dos.pdf')
shutil.move('total_dos.dat', 'total_dos/total_dos.dat')
shutil.move('mesh.conf', 'total_dos/mesh.conf')
shutil.move('phonopy.yaml', 'total_dos/phonopy.yaml')
shutil.move('phonopy_output.dat', 'total_dos/phonopy_output.dat')
os.chdir(cwd)
def calculate_phonon_pdos(path, atoms, dim, mesh, dos_range=None, atom_num=None, order=None):
''' Calculate the projected phonon DOS. Calls function to write mesh.conf file, and then cleans up the output by summing all the individual contributions per atom to the same species.'''
cwd = os.getcwd()
os.chdir(path)
write_mesh_conf(atoms, dim, mesh, dos_range=dos_range, pdos=True, atom_num=atom_num)
subprocess.call('phonopy -ps mesh.conf >> phonopy_output.dat', shell=True)
df = pd.read_csv('projected_dos.dat', delim_whitespace=True, skiprows=1, header=None, dtype=float)
# Loop over the columns and add according to "atoms" and "atom_num" lists
atoms_sum = 0
for atom, num in zip(atoms, atom_num):
df[atom] = df[1+atoms_sum]
for i in range(2+atoms_sum, atoms_sum+num+1):
df[atom] = df[atom] + df[i]
atoms_sum += num
# Remove all other columns, and rename the first column to "Frequency"
df.drop(df.iloc[:, 1:atoms_sum+1], inplace = True, axis = 1)
df.rename(columns = {0: "Frequency"}, inplace=True)
# If a list is passed to order, this will change the order of the atoms:
if order:
df_temp = pd.DataFrame()
df_temp["Frequency"] = df["Frequency"]
for atom in order:
df_temp[atom] = df[atom]
df = df_temp
# Save the cleaned up DataFrame to file.
df.to_csv('projected_dos_clean.dat')
# Make folder and move output in there
os.mkdir('projected_dos')
shutil.move('partial_dos.pdf', 'projected_dos/partial_dos.pdf')
shutil.move('projected_dos.dat', 'projected_dos/projected_dos.dat')
shutil.move('projected_dos_clean.dat', 'projected_dos/projected_dos_clean.dat')
shutil.move('mesh.conf', 'projected_dos/mesh.conf')
shutil.move('phonopy.yaml', 'projected_dos/phonopy.yaml')
shutil.move('phonopy_output.dat', 'projected_dos/phonopy_output.dat')
os.chdir(cwd)
def calculate_thermal_properties(path, atoms, dim, mesh, dos_range=None, tmax=None):
cwd = os.getcwd()
os.chdir(path)
write_mesh_conf(atoms, dim, mesh, dos_range=dos_range, tmax=tmax)
subprocess.call('phonopy -t mesh.conf >> phonopy_output.dat', shell=True)
with open('phonopy_output.dat', 'r') as f:
lines = f.readlines()
data = []
for ind, line in enumerate(lines):
if line.split():
if "#" in line.split()[0]:
j = 1
while lines[ind+j].split():
data.append(lines[ind+j].split())
j += 1
df = pd.DataFrame(data)
df.columns = ['T', 'F', 'S', 'C_v', 'E']
df.to_csv('thermal_properties.dat')
#Make folder and move output in there
os.mkdir('thermal_properties')
shutil.move('thermal_properties.yaml', 'thermal_properties/thermal_properties.yaml')
shutil.move('thermal_properties.dat', 'thermal_properties/thermal_properties.dat')
shutil.move('mesh.conf', 'thermal_properties/mesh.conf')
shutil.move('phonopy.yaml', 'thermal_properties/phonopy.yaml')
shutil.move('phonopy_output.dat', 'thermal_properties/phonopy_output.dat')
os.chdir(cwd)
def calculate_phonon_bandstructure(path, atoms, dim, mesh, kpoints='KPOINTS.bands', band_points=None):
cwd = os.getcwd()
os.chdir(path)
kpoints_coords, kpoints_labels = read_kpoints(kpoints)
band = write_phonopy_band_path(kpoints_coords)
write_band_conf(atoms, dim, mesh, band, band_points=band_points)
subprocess.call('phonopy band.conf >> phonopy_output.dat', shell=True)
write_phonon_bands()
os.mkdir('dispersion_relation')
shutil.move('band.conf', 'dispersion_relation/band.conf')
shutil.move('band.yaml', 'dispersion_relation/band.yaml')
shutil.move('bands', 'dispersion_relation/bands')
shutil.move('mesh.yaml', 'dispersion_relation/mesh.yaml')
shutil.move('phonopy.yaml', 'dispersion_relation/phonopy.yaml')
shutil.move('phonopy_output.dat', 'dispersion_relation/phonopy_output.dat')
os.chdir(cwd)
def write_phonon_bands(band='band.yaml'):
with open(band, 'r') as f:
lines = f.readlines()
kpoints = []
frequencies = []
for line in lines:
if 'distance' in line:
kpoints.append(line.split()[-1])
if 'frequency' in line:
frequencies.append(line.split()[-1])
number_of_kpoints = len(kpoints)
number_of_bands = len(frequencies) / number_of_kpoints
if not os.path.isdir('bands'):
os.mkdir('bands')
os.chdir('bands')
for i in range(int(number_of_bands)):
with open('band_{}.dat'.format(i+1), 'w') as b:
for ind, kpoint in enumerate(kpoints):
if ind == len(kpoints)-1:
b.write("{} {}".format(kpoint, frequencies[ind*int(number_of_bands)+i]))
else:
b.write("{} {}\n".format(kpoint, frequencies[ind*int(number_of_bands)+i]))
os.chdir('../')
def plot_phonon_dos(dos_path='total_dos.dat', options={}):
required_options = ['xlim', 'ylim', 'flip_xy', 'colours', 'palettes', 'rc_params', 'format_params']
default_options = {
'xlim': None, # x-limits
'ylim': None, # y-limits
'flip_xy': False, # Whether to flip what is plotted on the x- and y-axes respectively. Default is False and plots frequency along x-axis and density of states along y-axis.
'colours': None,
'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
'format_params': {},
'rc_params': {}
}
options = update_options(options=options, required_options=required_options, default_options=default_options)
fig, ax = prepare_plot(options=options)
dos = read_phonon_dos(dos_path=dos_path)
if not options['xlim']:
options['xlim'] = [dos["Frequency"].min(), dos["Frequency"].max()]
if not options['ylim']:
options['ylim'] = [dos["DOS"].min(), dos["DOS"].max()*1.1]
if not options['colours']:
colours = generate_colours(palette=options['palette'])
else:
colours = itertools.cycle(options['colours'])
if options['flip_xy']:
dos.plot(x='DOS', y='Frequency', ax=ax, color=colours[0])
else:
dos.plot(x='Frequency', y='DOS', ax=ax, color=colours[0])
options['plot_kind'] = 'DOS'
fig, ax = prettify_dos_plot(fig=fig, ax=ax, options=options)
ax.get_legend().remove()
return fig, ax
def prepare_plot(options={}):
rc_params = options['rc_params']
format_params = options['format_params']
required_options = ['single_column_width', 'double_column_width', 'column_type', 'width_ratio', 'aspect_ratio', 'compress_width', 'compress_height', 'upscaling_factor', 'dpi']
default_options = {
'single_column_width': 8.3,
'double_column_width': 17.1,
'column_type': 'single',
'width_ratio': '1:1',
'aspect_ratio': '1:1',
'compress_width': 1,
'compress_height': 1,
'upscaling_factor': 1.0,
'dpi': 600,
}
options = update_options(format_params, required_options, default_options)
# Reset run commands
plt.rcdefaults()
# Update run commands if any is passed (will pass an empty dictionary if not passed)
update_rc_params(rc_params)
width = determine_width(options)
height = determine_height(options, width)
width, height = scale_figure(options=options, width=width, height=height)
fig, ax = plt.subplots(figsize=(width, height), dpi=options['dpi'])
return fig, ax
def plot_phonon_pdos(path='projected_dos_clean.dat', options={}):
required_options = ['xlim', 'ylim', 'flip_xy', 'colours', 'palettes', 'normalise', 'poscar', 'atoms', 'rc_params', 'format_params']
default_options = {
'xlim': None, # x-limits
'ylim': None, # y-limits
'flip_xy': False, # Whether to flip what is plotted on the x- and y-axes respectively. Default is False and plots frequency along x-axis and density of states along y-axis.
'colours': None,
'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
'normalise': False,
'poscar': None,
'atoms': [],
'format_params': {},
'rc_params': {}
}
options = update_options(options=options, required_options=required_options, default_options=default_options)
dos = read_phonon_pdos(path=path, normalise=options['normalise'], poscar=options['poscar'])
fig, ax = prepare_plot(options=options)
if not options['xlim']:
options['xlim'] = [dos["Frequency"].min(), dos["Frequency"].max()]
if not options['ylim'] and options['atoms']:
ymin = 0
ymax = 0
for atom in options['atoms']:
if dos[atom].min() < ymin:
ymin = dos[atom].min()
if dos[atom].max() > ymax:
ymax = dos[atom].max()
options['ylim'] = [ymin, ymax*1.1]
if not options['colours']:
colours = generate_colours(palette=options['palette'])
else:
colours = itertools.cycle(options['colours'])
for ind, atom in enumerate(options['atoms']):
if options['flip_xy']:
dos.plot(x=atom, y='Frequency', ax=ax, color=next(colours))
else:
dos.plot(x='Frequency', y=atom, ax=ax, color=next(colours))
options['plot_kind'] = 'PDOS'
prettify_dos_plot(fig=fig, ax=ax, options=options)
return fig, ax
def plot_phonon_bandstructure(band_folder='bands', kpoints='KPOINTS.bands', options={}, title=None, xlim=None, ylim=None, pad_bottom=None, scale=1, square=True, width=None, height=None, dpi=None, rotation=None, xpad=None, ypad=None):
# Get the special points labels
kpoint_coords, kpoint_labels = read_kpoints(kpoints)
kpoint_labels = get_kpoints_labels(kpoint_labels)
# Get current folder and change into the folder containing bands
cwd = os.getcwd()
os.chdir(band_folder)
band_paths = [band for band in os.listdir() if os.path.isfile(band) and band[0:4] == 'band']
# Get the location of the special points along the x-axis
kpoint_ticks = get_kpoints_ticks(band_paths[0])
bands = []
for band_path in band_paths:
bands.append(read_band(band_path))
fig, ax = prepare_plot(options=options)
mod = importlib.import_module("palettable.colorbrewer.%s" % 'qualitative')
colour = getattr(mod, 'Dark2_3').mpl_colors[0]
kpt_min = None
kpt_max = None
freq_min = None
freq_max = None
for band in bands:
if kpt_min == None or band["kpt"].min() < kpt_min:
kpt_min = band["kpt"].min()
if kpt_max == None or band["kpt"].max() > kpt_max:
kpt_max = band["kpt"].max()
if freq_min == None or band["frequency"].min() < freq_min:
freq_min = band["frequency"].min()
if freq_max == None or band["frequency"].max() > freq_max:
freq_max = band["frequency"].max()
band.plot('kpt', 'frequency', ax=ax, color=colour)
if not xlim:
xlim = [kpt_min, kpt_max]
if not ylim:
ylim = [freq_min-freq_max*0.1, freq_max+freq_max*0.1]
ax.get_legend().remove()
prettify_plot(fig=fig, ax=ax, special_points_labels=kpoint_labels, special_points_coords=kpoint_ticks, xlim=xlim, ylim=ylim, title=title, pad_bottom=pad_bottom, scale=scale, rotation=rotation, xpad=xpad, ypad=ypad)
os.chdir(cwd)
def prepare_plot_old(width=None, height=None, square=True, dpi=None, colour_cycle=('qualitative', 'Dark2_8'), temperatureunit='K', energyunit='eV f.u.$^{-1}$', scale=1):
linewidth = 3*scale
axeswidth = 3*scale
plt.rc('lines', linewidth=linewidth)
plt.rc('axes', linewidth=axeswidth)
if square:
if not width:
width = 20
if not height:
height = width
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(width, height), facecolor='w', dpi=dpi)
return fig, ax
def prettify_plot(fig, ax, frequencyunit='THz', special_points_coords=None, special_points_labels=None, xlim=None, ylim=None, title=None, pad_bottom=None, scale=1, rotation=None, xpad=None, ypad=None):
# Set sizes of ticks, labes etc.
ticksize = 30*scale
labelsize = 30*scale
legendsize = 30*scale
titlesize = 30*scale
linewidth = 3*scale
axeswidth = 3*scale
majorticklength = 20*scale
minorticklength = 10*scale
# Set labels on x- and y-axes
if ypad:
ax.set_ylabel('Frequency [{}]'.format(frequencyunit), size=labelsize, labelpad=ypad)
else:
ax.set_ylabel('Frequency [{}]'.format(frequencyunit), size=labelsize)
ax.set_xlabel('')
ax.tick_params(axis='y', direction='in', which='major', right=True, length=10, width=0.5)
ax.tick_params(axis='y', direction='in', which='minor', right=True, length=5, width=0.5)
ax.tick_params(axis='x', direction='in', which='major', bottom=False)
ax.yaxis.set_major_locator(MultipleLocator(5))
ax.yaxis.set_minor_locator(MultipleLocator(2.5))
plt.xticks(fontsize=ticksize)
plt.yticks(fontsize=ticksize)
# Set tick parameters
if special_points_coords:
for coord in special_points_coords:
plt.axvline(coord, color='black', linestyle='--', linewidth=0.5)
plt.xticks(ticks=special_points_coords, labels=special_points_labels, rotation=rotation)
if xlim:
plt.xlim(xlim)
if ylim:
plt.ylim(ylim)
if title:
ax.set_title(title, size=40)
if pad_bottom is not None:
bigax = fig.add_subplot(111)
bigax.set_facecolor([1,1,1,0])
bigax.spines['top'].set_visible(False)
bigax.spines['bottom'].set_visible(True)
bigax.spines['left'].set_visible(False)
bigax.spines['right'].set_visible(False)
bigax.tick_params(labelcolor='w', color='w', direction='in', top=False, bottom=True, left=False, right=False, labelleft=False, pad=pad_bottom)
if xpad:
ax.tick_params(axis='x', pad=xpad)
if ypad:
ax.tick_params(axis='y', pad=ypad)
return fig, ax
def prettify_dos_plot(fig, ax, options, frequencyunit='THz', dosunit='a.u.', xlim=None, ylim=None, title=None, hide_ylabels=False, flip_xy=False, pad_bottom=None, scale=1, pdos=False, colours=None, atoms=None, xpad=None, ypad=None):
required_options = ['plot_kind', 'flip_xy', 'hide_x_labels', 'hide_y_labels', 'xlabel', 'ylabel', 'xunit', 'yunit', 'xlim', 'ylim', 'x_tick_locators', 'y_tick_locators', 'hide_x_ticks', 'hide_y_ticks', 'hide_x_ticklabels', 'hide_y_ticklabels',
'colours', 'palettes', 'title', 'legend', 'legend_position', 'subplots_adjust', 'text']
default_options = {
'plot_kind': 'DOS', # DOS or PDOS
'flip_xy': False,
'hide_x_labels': False, # Whether x labels should be hidden
'hide_x_ticklabels': False,
'hide_x_ticks': False,
'hide_y_labels': False, # whether y labels should be hidden
'hide_y_ticklabels': False,
'hide_y_ticks': False,
'xlabel': 'Frequency',
'ylabel': 'DOS',
'xunit': r'THz', # The unit of the x-values in the curve plot
'yunit': r'a.u.', # The unit of the y-values in the curve and bar plots
'xlim': None,
'ylim': None,
'x_tick_locators': [5, 2.5], # Major and minor tick locators
'y_tick_locators': [10, 5],
'colours': None,
'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
'title': None,
'legend': True,
'legend_position': ['upper center', (0.20, 0.90)], # the position of the legend passed as arguments to loc and bbox_to_anchor respectively
'subplots_adjust': [0.1, 0.1, 0.9, 0.9],
'text': None
}
if 'plot_kind' in options.keys():
if 'ylabel' not in options.keys():
if options['plot_kind'] == 'DOS':
options['ylabel'] = 'Density of states'
elif options['plot_kind'] == 'PDOS':
options['ylabel'] = 'PDOS'
options = update_options(options=options, required_options=required_options, default_options=default_options)
if options['flip_xy']:
# Switch all the x- and y-specific values
options = swap_values(dict=options, key1='xlim', key2='ylim')
options = swap_values(dict=options, key1='xunit', key2='yunit')
options = swap_values(dict=options, key1='xlabel', key2='ylabel')
options = swap_values(dict=options, key1='x_tick_locators', key2='y_tick_locators')
options = swap_values(dict=options, key1='hide_x_labels', key2='hide_y_labels')
# Set labels on x- and y-axes
if not options['hide_y_labels']:
ax.set_ylabel(f'{options["ylabel"]} [{options["yunit"]}]')
else:
ax.set_ylabel('')
if not options['hide_x_labels']:
ax.set_xlabel(f'{options["xlabel"]} [{options["xunit"]}]')
else:
ax.set_xlabel('')
# Hide x- and y- ticklabels
if options['hide_y_ticklabels']:
ax.tick_params(axis='y', direction='in', which='both', labelleft=False, labelright=False)
if options['hide_x_ticklabels']:
ax.tick_params(axis='x', direction='in', which='both', labelbottom=False, labeltop=False)
# Hide x- and y-ticks:
if options['hide_y_ticks']:
ax.tick_params(axis='y', direction='in', which='both', left=False, right=False)
if options['hide_x_ticks']:
ax.tick_params(axis='x', direction='in', which='both', bottom=False, top=False)
# Set multiple locators
ax.yaxis.set_major_locator(MultipleLocator(options['y_tick_locators'][0]))
ax.yaxis.set_minor_locator(MultipleLocator(options['y_tick_locators'][1]))
ax.xaxis.set_major_locator(MultipleLocator(options['x_tick_locators'][0]))
ax.xaxis.set_minor_locator(MultipleLocator(options['x_tick_locators'][1]))
# Set title
if options['title']:
ax.set_title(options['title'])
# Generate colours
if not options['colours']:
colours = generate_colours(palette=options['palette'])
else:
colours = itertools.cycle(options['colours'])
# Create legend
if ax.get_legend():
ax.get_legend().remove()
if options['legend']:
if options['plot_kind'] == 'PDOS' and options['atoms']:
# Create legend
patches = []
for atom in options['atoms']:
patches.append(mpatches.Patch(color=next(colours), label=atom))
fig.legend(handles=patches, loc=options['legend_position'][0], bbox_to_anchor=options['legend_position'][1], frameon=False)
# Adjust where the axes start within the figure. Default value is 10% in from the left and bottom edges. Used to make room for the plot within the figure size (to avoid using bbox_inches='tight' in the savefig-command, as this screws with plot dimensions)
plt.subplots_adjust(left=options['subplots_adjust'][0], bottom=options['subplots_adjust'][1], right=options['subplots_adjust'][2], top=options['subplots_adjust'][3])
# If limits for x- and y-axes is passed, sets these.
if options['xlim'] is not None:
ax.set_xlim(options['xlim'])
if options['ylim'] is not None:
ax.set_ylim(options['ylim'])
# Add custom text
if options['text']:
plt.text(x=options['text'][1][0], y=options['text'][1][1], s=options['text'][0])
return fig, ax
def read_thermal_properties(path, number_of_formula_units=None, convert=True):
kJ = 6.2415064799632E+21
Na = 6.0221415E+23
thermal_properties = pd.read_csv(path, skiprows=1, index_col=0)
thermal_properties.columns = ['T', 'F', 'S', 'Cv', 'E']
if convert:
thermal_properties.F = thermal_properties.F / Na * kJ
thermal_properties.S = thermal_properties.S / Na * kJ
thermal_properties.Cv = thermal_properties.Cv / Na * kJ
thermal_properties.E = thermal_properties.E / Na * kJ
if number_of_formula_units:
thermal_properties.F = thermal_properties.F / number_of_formula_units
thermal_properties.S = thermal_properties.S / number_of_formula_units
thermal_properties.Cv = thermal_properties.Cv / number_of_formula_units
thermal_properties.E = thermal_properties.E / number_of_formula_units
return thermal_properties
def plot_thermal_properties(path, number_of_formula_units=None, convert=True):
thermal_properties = read_thermal_properties(path=path, number_of_formula_units=number_of_formula_units, convert=convert)
thermal_properties.plot(x='T', y=['F', 'S', 'Cv', 'E'])
def get_adjusted_energies(paths, equilibrium_energies, options={}):
required_options = ['plot_kind', 'reference', 'number_of_formula_units', 'xlim', 'ylim', 'flip_xy', 'colours', 'palettes', 'normalise', 'poscar', 'atoms', 'rc_params', 'format_params']
default_options = {
'plot_kind': 'absolute',
'reference': 0,
'number_of_formula_units': None,
'xlim': None, # x-limits
'ylim': None, # y-limits
'flip_xy': False, # Whether to flip what is plotted on the x- and y-axes respectively. Default is False and plots frequency along x-axis and density of states along y-axis.
'colours': None,
'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
'normalise': False,
'poscar': None,
'atoms': [],
'format_params': {},
'rc_params': {}
}
options = update_options(options=options, required_options=required_options, default_options=default_options)
dfs = []
if not options['number_of_formula_units']:
options['number_of_formula_units'] = [None for i in range(len(paths))]
for ind, path in enumerate(paths):
df = read_thermal_properties(path, options['number_of_formula_units'][ind])
dfs.append(df)
for ind, df in enumerate(dfs):
df["adjusted_energy"] = equilibrium_energies[ind] + df["F"]
if options['plot_kind'] == 'difference':
for ind, df in enumerate(dfs):
df["reference_energy"] = dfs[options['reference']]["adjusted_energy"]
df["difference_energy"] = df["adjusted_energy"] - df["reference_energy"]
if options['plot_kind'] == 'relative':
for ind, df in enumerate(dfs):
df["reference_energy"] = dfs[options['reference']]["adjusted_energy"].iloc[0]
df["relative_energy"] = df["adjusted_energy"] - df["reference_energy"]
return dfs
def find_low_energy_structures_at_extremas(dfs):
energy_low_T = -1
low_T_ind = -1
energy_high_T = -1
high_T_ind = -1
for ind, df in enumerate(dfs):
if low_T_ind == -1:
low_T_ind = ind
energy_low_T = df['adjusted_energy'].loc[df['T'] == df['T'].min()].values[0]
elif df['adjusted_energy'].loc[df['T'] == df['T'].min()].values[0] < energy_low_T:
low_T_ind = ind
energy_low_T = df['adjusted_energy'].loc[df['T'] == df['T'].min()].values[0]
if high_T_ind == -1:
high_T_ind = ind
energy_high_T = df['adjusted_energy'].loc[df['T'] == df['T'].max()].values[0]
elif df['adjusted_energy'].loc[df['T'] == df['T'].max()].values[0] < energy_high_T:
high_T_ind = ind
energy_high_T = df['adjusted_energy'].loc[df['T'] == df['T'].max()].values[0]
return [low_T_ind, high_T_ind]
def find_intersection(dfs, ind1, ind2):
intersection = -1
for T in dfs[0]['T']:
if dfs[ind2]['adjusted_energy'].loc[dfs[ind2]['T'] == T].values[0] < dfs[ind1]['adjusted_energy'].loc[dfs[ind1]['T'] == T].values[0]:
intersection = T
break
return intersection
def plot_adjusted_energies(paths, equilibrium_energies, options={}):
''' This function plots the adjusted total energies of a set of structures given a set of thermal properties calculated using phonopy.
paths: List of paths (strings) to the .csv-files with thermal properties.
equilibrium_energies: List of equilibrium energies (floats) of pristine calculations
labels: List of labels (strings) to be shown in the plot
mode: Whether to plot as a difference plot ("difference_plot") or absolute units ("absolute"). Defaults to absolute
difference_reference: Index of which structure should serve as the reference. Defaults to 0.
number_of_formula_units: List of number of formula units per unit cell (int, float) to scale the data properly. Defaults to None, meaning to scaling.
width: Width of the plot. Defaults to None, meaning standard width is used.
width: Height of the plot. Defaults to None, meaning standard height is used.
dpi: Dots per inch. Defaults to None, meaning standard dpi is used.
colour_cycle: Tuple with type of colour scheme from the colorbrewer: http://jiffyclub.github.io/palettable/colorbrewer/
temperatureunit: The unit to plot the temperature in. Only K implemented so far.
energyunit: The unit to plot the energy in. Only eV per f.u. impleneted so far.
inset: Whether or not there should be an inset. This is not very well implemented, and may cause issues. Defaults to False.
inset_lims: The x-limits of the inset. Defaults to None, meaning it will just try to figure it out itself.
'''
required_options = ['plot_kind', 'reference', 'number_of_formula_units', 'labels', 'xlim', 'ylim', 'colours', 'palettes', 'linestyles', 'rc_params', 'format_params', 'inset_xlim', 'inset_ylim', 'draw_intersection_main', 'draw_intersection_inset', 'intersection_indices', 'intersection_lw']
default_options = {
'plot_kind': 'absolute',
'reference': 0,
'number_of_formula_units': None,
'labels': None,
'xlim': None, # x-limits
'ylim': None, # y-limits
'inset_xlim': None,
'inset_ylim': None,
'colours': None,
'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
'linestyles': ['solid', 'dotted', 'dashed'],
'format_params': {},
'rc_params': {},
'draw_intersection_main': False,
'draw_intersection_inset': False,
'intersection_indices': None,
'intersection_lw': None,
}
options = update_options(options=options, required_options=required_options, default_options=default_options)
energy_dfs = get_adjusted_energies(paths=paths, equilibrium_energies=equilibrium_energies, options=options)
fig, ax = prepare_plot(options=options)
if not options['labels']:
options['labels'] = ['_' for i in range(len(paths))]
if not options['colours']:
colours = generate_colours(palettes=options['palettes'])
else:
colours = itertools.cycle(options['colours'])
linestyles = itertools.cycle(options['linestyles'])
for df in energy_dfs:
if options['plot_kind'] == 'difference':
df.plot(x='T', y='difference_energy', ax=ax, ls=next(linestyles), c=next(colours))
elif options['plot_kind'] == 'relative':
df.plot(x='T', y='relative_energy', ax=ax, ls=next(linestyles), c=next(colours))
elif options['plot_kind'] == 'absolute':
df.plot(x='T', y='adjusted_energy', ax=ax, ls=next(linestyles), c=next(colours))
ax.set_xlim([int(df["T"].min()), int(df["T"].max())])
fig, ax = prettify_thermal_plot(fig=fig, ax=ax, options=options)
if options['inset_xlim']:
inset_ax = prepare_inset_axes(ax, options)
if not options['colours']:
colours = generate_colours(palettes=options['palettes'])
else:
colours = itertools.cycle(options['colours'])
linestyles = itertools.cycle(options['linestyles'])
for df in energy_dfs:
if options['plot_kind'] =='absolute':
y = 'adjusted_energy'
elif options['plot_kind'] == 'relative':
y = 'relative_energy'
elif options['plot_kind'] == 'difference':
y = 'difference_energy'
df.loc[(df["T"] >= options['inset_xlim'][0]) & (df["T"] <= options['inset_xlim'][1])].plot(x='T', y=y, ax=inset_ax, ls=next(linestyles), c=next(colours))
inset_ax.set_xlim([options['inset_xlim'][0], options['inset_xlim'][1]])
if options['inset_ylim']:
inset_ax.set_ylim([options['inset_ylim'][0], options['inset_ylim'][1]])
inset_ax.get_legend().remove()
inset_ax.set_xlabel('')
if options['draw_intersection_main'] or options['draw_intersection_inset']:
if not options['intersection_indices']:
options['intersection_indices'] = find_low_energy_structures_at_extremas(energy_dfs)
if not options['intersection_lw']:
options['intersection_lw'] = plt.rcParams['lines.linewidth']
intersection = find_intersection(energy_dfs, options['intersection_indices'][0], options['intersection_indices'][1])
if options['draw_intersection_main']:
ax.axvline(x=intersection, ls='dashed', c='black', lw=options['intersection_lw'])
if options['draw_intersection_inset']:
inset_ax.axvline(x=intersection, ls='dashed', c='black', lw=options['intersection_lw'])
return fig, ax
def prepare_thermal_plot(width=None, height=None, dpi=None, colour_cycle=('qualitative', 'Dark2_8'), temperatureunit='K', energyunit='eV f.u.$^{-1}$', scale=1):
linewidth = 3*scale
axeswidth = 3*scale
plt.rc('lines', linewidth=linewidth)
plt.rc('axes', linewidth=axeswidth)
if not width:
width = 20
if not height:
height = width
fig = plt.figure(figsize=(width, height), facecolor='w', dpi=dpi)
ax = plt.gca()
# Set colour cycle
mod = importlib.import_module("palettable.colorbrewer.%s" % colour_cycle[0])
colors = getattr(mod, colour_cycle[1]).mpl_colors
ax.set_prop_cycle(cycler('color', colors))
return fig, ax
def prettify_thermal_plot(fig, ax, options):
required_options = ['plot_kind', 'hide_x_labels', 'hide_y_labels', 'rotation_x_ticks', 'rotation_y_ticks', 'xlabel', 'ylabel', 'xunit', 'yunit', 'xlim', 'ylim', 'x_tick_locators', 'y_tick_locators', 'hide_x_ticks', 'hide_y_ticks', 'hide_x_ticklabels', 'hide_y_ticklabels',
'colours', 'palettes', 'title', 'legend', 'legend_position', 'subplots_adjust', 'text']
default_options = {
'plot_kind': 'absolute', # absolute, relative, difference
'hide_x_labels': False, # Whether x labels should be hidden
'hide_x_ticklabels': False,
'hide_x_ticks': False,
'rotation_x_ticks': 0,
'hide_y_labels': False, # whether y labels should be hidden
'hide_y_ticklabels': False,
'hide_y_ticks': False,
'rotation_y_ticks': 0,
'xlabel': 'Temperature',
'ylabel': 'Energy',
'xunit': r'K', # The unit of the x-values in the curve plot
'yunit': r'eV f.u.$^{-1}$', # The unit of the y-values in the curve and bar plots
'xlim': None,
'ylim': None,
'x_tick_locators': [100, 50], # Major and minor tick locators
'y_tick_locators': [10, 5],
'labels': None,
'colours': None,
'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
'title': None,
'legend': True,
'legend_position': ['upper center', (0.20, 0.90)], # the position of the legend passed as arguments to loc and bbox_to_anchor respectively
'subplots_adjust': [0.1, 0.1, 0.9, 0.9],
'text': None
}
if 'plot_kind' in options.keys():
if 'ylabel' not in options.keys():
if options['plot_kind'] == 'absolute':
options['ylabel'] = 'Energy'
elif options['plot_kind'] == 'relative':
options['ylabel'] = 'Relative energy'
elif options['plot_kind'] == 'difference':
options['ylabel'] = 'Energy difference'
if 'y_tick_locators' not in options.keys():
if options['plot_kind'] == 'absolute' or options['plot_kind'] == 'relative':
options['y_tick_locators'] = [1, 0.5]
elif options['plot_kind'] == 'difference':
options['y_tick_locators'] = [0.1, 0.05]
options = update_options(options=options, required_options=required_options, default_options=default_options)
# Set labels on x- and y-axes
if not options['hide_y_labels']:
ax.set_ylabel(f'{options["ylabel"]} [{options["yunit"]}]')
else:
ax.set_ylabel('')
if not options['hide_x_labels']:
ax.set_xlabel(f'{options["xlabel"]} [{options["xunit"]}]')
else:
ax.set_xlabel('')
# Set multiple locators
ax.yaxis.set_major_locator(MultipleLocator(options['y_tick_locators'][0]))
ax.yaxis.set_minor_locator(MultipleLocator(options['y_tick_locators'][1]))
ax.xaxis.set_major_locator(MultipleLocator(options['x_tick_locators'][0]))
ax.xaxis.set_minor_locator(MultipleLocator(options['x_tick_locators'][1]))
# Hide x- and y- ticklabels
if options['hide_y_ticklabels']:
ax.tick_params(axis='y', direction='in', which='both', labelleft=False, labelright=False)
else:
plt.xticks(rotation=options['rotation_x_ticks'])
#ax.set_xticklabels(ax.get_xticks(), rotation = options['rotation_x_ticks'])
if options['hide_x_ticklabels']:
ax.tick_params(axis='x', direction='in', which='both', labelbottom=False, labeltop=False)
else:
pass
#ax.set_yticklabels(ax.get_yticks(), rotation = options['rotation_y_ticks'])
# Hide x- and y-ticks:
if options['hide_y_ticks']:
ax.tick_params(axis='y', direction='in', which='both', left=False, right=False)
if options['hide_x_ticks']:
ax.tick_params(axis='x', direction='in', which='both', bottom=False, top=False)
# Set title
if options['title']:
ax.set_title(options['title'])
# Create legend
if ax.get_legend():
ax.get_legend().remove()
if options['legend']:
# Make palette and linestyles from original parameters
if not options['colours']:
colours = generate_colours(palettes=options['palettes'])
else:
colours = itertools.cycle(options['colours'])
linestyles = itertools.cycle(options['linestyles'])
# Create legend
custom_lines = []
active_labels = []
for label in options['labels']:
# Discard next linestyle and colour if label is _
if label == '_':
_ = next(colours)
_ = next(linestyles)
else:
custom_lines.append(Line2D([0], [0], color=next(colours), ls=next(linestyles)))
active_labels.append(label)
ax.legend(custom_lines, active_labels, frameon=False, loc=options['legend_position'][0], bbox_to_anchor=options['legend_position'][1])
#fig.legend(handles=patches, loc=options['legend_position'][0], bbox_to_anchor=options['legend_position'][1], frameon=False)
# Adjust where the axes start within the figure. Default value is 10% in from the left and bottom edges. Used to make room for the plot within the figure size (to avoid using bbox_inches='tight' in the savefig-command, as this screws with plot dimensions)
plt.subplots_adjust(left=options['subplots_adjust'][0], bottom=options['subplots_adjust'][1], right=options['subplots_adjust'][2], top=options['subplots_adjust'][3])
# If limits for x- and y-axes is passed, sets these.
if options['xlim'] is not None:
ax.set_xlim(options['xlim'])
if options['ylim'] is not None:
ax.set_ylim(options['ylim'])
# Add custom text
if options['text']:
plt.text(x=options['text'][1][0], y=options['text'][1][1], s=options['text'][0])
return fig, ax
def prettify_thermal_plot_old(fig, ax, options, colour_cycle=('qualitative', 'Dark2_8'), temperatureunit='K', energyunit='eV f.u.$^{-1}$', mode='absolute', scale=1, linestyles=None, labels=None, xpad=None, ypad=None):
# Set sizes of ticks, labes etc.
ticksize = 30*scale
labelsize = 30*scale
legendsize = 30*scale
titlesize = 30*scale
linewidth = 3*scale
axeswidth = 3*scale
majorticklength = 20*scale
minorticklength = 10*scale
xpad = 4 if not xpad else xpad
ypad = 4 if not ypad else ypad
# Set labels on x- and y-axes
ax.set_xlabel('Temperature [{}]'.format(temperatureunit), size=labelsize, labelpad=xpad)
if mode == 'absolute':
ax.set_ylabel('Total energy [{}]'.format(energyunit), size=labelsize, labelpad=ypad)
elif mode == 'difference_plot':
ax.set_ylabel('Energy difference [{}]'.format(energyunit), size=labelsize, labelpad=ypad)
elif mode == 'relative':
ax.set_ylabel('Relative energy [{}]'.format(energyunit), size=labelsize, labelpad=ypad)
# Set tick parameters
ax.tick_params(axis='x', direction='in', which='major', top=True, length=majorticklength, width=axeswidth, pad=xpad)
ax.tick_params(axis='x', direction='in', which='minor', top=True, length=minorticklength, width=axeswidth)
ax.tick_params(axis='y', direction='in', which='major', right=True, length=majorticklength, width=axeswidth, pad=ypad)
ax.tick_params(axis='y', direction='in', which='minor', right=True, length=minorticklength, width=axeswidth)
ax.xaxis.set_major_locator(MultipleLocator(100))
ax.xaxis.set_minor_locator(MultipleLocator(50))
if mode == 'absolute':
ax.yaxis.set_major_locator(MultipleLocator(1))
ax.yaxis.set_minor_locator(MultipleLocator(0.5))
elif mode == 'difference_plot':
ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(MultipleLocator(0.05))
if labels:
custom_lines = []
if not linestyles:
linestyles = ['solid', 'dotted', 'dashed']
mod = importlib.import_module("palettable.colorbrewer.%s" % colour_cycle[0])
palette = getattr(mod, colour_cycle[1]).mpl_colors
palette = itertools.cycle(palette)
colours = []
for label in labels:
colours.append(next(palette))
patches = []
for ind, label in enumerate(labels):
custom_lines.append(Line2D([0], [0], color=colours[ind], lw=linewidth, ls=linestyles[ind]))
ax.legend(custom_lines, labels, fontsize=labelsize, frameon=False)
plt.xticks(fontsize=ticksize, rotation=45)
plt.yticks(fontsize=ticksize)
return fig, ax
def prepare_inset_axes(parent_ax, options):
default_options = {
'hide_inset_x_labels': False, # Whether x labels should be hidden
'hide_inset_x_ticklabels': False,
'hide_inset_x_ticks': False,
'rotation_inset_x_ticks': 0,
'hide_inset_y_labels': False, # whether y labels should be hidden
'hide_inset_y_ticklabels': False,
'hide_inset_y_ticks': False,
'rotation_inset_y_ticks': 0,
'inset_x_tick_locators': [100, 50], # Major and minor tick locators
'inset_y_tick_locators': [10, 5],
'inset_position': [0.1,0.1,0.3,0.3],
'legend_position': ['upper center', (0.20, 0.90)], # the position of the legend passed as arguments to loc and bbox_to_anchor respectively
'connecting_corners': [1,2]
}
options = update_options(options=options, required_options=default_options.keys(), default_options=default_options)
# Create a set of inset Axes: these should fill the bounding box allocated to
# them.
inset_ax = plt.axes([0, 0, 2, 2])
# Manually set the position and relative size of the inset axes within ax1
ip = InsetPosition(parent_ax, options['inset_position'])
inset_ax.set_axes_locator(ip)
mark_inset(parent_ax, inset_ax, loc1a=2, loc2a=4, loc1b=1, loc2b=2, fc='none', ec='black')
inset_ax.xaxis.set_major_locator(MultipleLocator(options['inset_x_tick_locators'][0]))
inset_ax.xaxis.set_minor_locator(MultipleLocator(options['inset_x_tick_locators'][1]))
inset_ax.yaxis.set_major_locator(MultipleLocator(options['inset_y_tick_locators'][0]))
inset_ax.yaxis.set_minor_locator(MultipleLocator(options['inset_y_tick_locators'][1]))
return inset_ax
def update_rc_params(rc_params):
''' Update all passed run commands in matplotlib'''
if rc_params:
for key in rc_params.keys():
plt.rcParams.update({key: rc_params[key]})
def update_options(options, required_options, default_options):
''' Update all passed options'''
for option in required_options:
if option not in options.keys():
options[option] = default_options[option]
return options
def determine_width(options):
conversion_cm_inch = 0.3937008 # cm to inch
if options['column_type'] == 'single':
column_width = options['single_column_width']
elif options['column_type'] == 'double':
column_width = options['double_column_width']
column_width *= conversion_cm_inch
width_ratio = [float(num) for num in options['width_ratio'].split(':')]
width = column_width * width_ratio[0]/width_ratio[1]
return width
def determine_height(options, width):
aspect_ratio = [float(num) for num in options['aspect_ratio'].split(':')]
height = width/(aspect_ratio[0] / aspect_ratio[1])
return height
def scale_figure(options, width, height):
width = width * options['upscaling_factor'] * options['compress_width']
height = height * options['upscaling_factor'] * options['compress_height']
return width, height
def swap_values(dict, key1, key2):
key1_val = dict[key1]
dict[key1] = dict[key2]
dict[key2] = key1_val
return dict
def generate_colours(palettes):
# Creates a list of all the colours that is passed in the colour_cycles argument. Then makes cyclic iterables of these.
colour_collection = []
for palette in palettes:
mod = importlib.import_module("palettable.colorbrewer.%s" % palette[0])
colour = getattr(mod, palette[1]).mpl_colors
colour_collection = colour_collection + colour
colour_cycle = itertools.cycle(colour_collection)
return colour_cycle