1228 lines
No EOL
44 KiB
Python
1228 lines
No EOL
44 KiB
Python
import re
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import os
|
|
import linecache
|
|
import nafuma.dft as dft
|
|
import nafuma.auxillary as aux
|
|
import nafuma.plotting as btp
|
|
|
|
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,AutoMinorLocator)
|
|
import importlib
|
|
import mpl_toolkits.axisartist as axisartist
|
|
from cycler import cycler
|
|
import itertools
|
|
import matplotlib.patches as mpatches
|
|
|
|
|
|
|
|
|
|
def count_electrons(pdos, orbital, interval=None, r=None, scale=None):
    ''' Counts electrons in the specified orbitals of one or more projected density of states DataFrames.

    Every orbital column is integrated as sum(|DOS|) * dE over the grid points lying strictly inside the
    energy interval, where dE = (E_max - E_min) / (number of grid points).

    Inputs:
    pdos: either an individual DOS DataFrame or a list of them. Each needs an "Energy" column and one column per requested orbital
    orbital: a column name, or a list of column names, to count the electrons from
    interval: [lower limit, upper limit] in the unit of the "Energy" column. If not given, each DOS is integrated over its own full energy range
    r: number of decimals the results should be rounded to (0 is a valid choice). No rounding if None
    scale: a factor all electron counts are multiplied with, e.g. to go from electrons per unit cell to electrons per atom

    Output:
    nelec: the total number of electrons summed over all DOSes and orbitals
    [nelec_per_dos, nelec_per_orbital]: the electron count of each DOS, and a list of lists with the electron count of each orbital in each DOS
    '''

    orbitals = orbital if isinstance(orbital, list) else [orbital]
    dos_list = pdos if isinstance(pdos, list) else [pdos]

    nelec = 0
    nelec_per_dos = []
    nelec_per_orbital = []

    for d in dos_list:
        energy = d["Energy"]
        dE = (energy.max() - energy.min()) / len(energy)

        # The limits are resolved per DOS: a default derived from one DOS must not
        # carry over to the next one, which may span a different energy range
        if interval:
            emin, emax = interval[0], interval[1]
        else:
            emin, emax = energy.min(), energy.max()

        energy_values = np.asarray(energy)
        inside = (energy_values > emin) & (energy_values < emax)

        nelec_orbitals = []
        for o in orbitals:
            nelec_orbital = np.abs(np.asarray(d[o])[inside]).sum() * dE
            nelec += nelec_orbital
            nelec_orbitals.append(nelec_orbital)

        # Scale the values if specified
        if scale:
            nelec_orbitals = [n * scale for n in nelec_orbitals]

        # The per-DOS count is summed from the unrounded orbital counts and rounded afterwards
        nelec_dos = sum(nelec_orbitals)
        if r is not None:
            nelec_dos = np.round(nelec_dos, r)
            nelec_orbitals = [np.round(n, r) for n in nelec_orbitals]

        nelec_per_dos.append(nelec_dos)
        nelec_per_orbital.append(nelec_orbitals)

    # The grand total is scaled and rounded last; the list entries were handled above
    if scale:
        nelec = nelec * scale

    if r is not None:
        nelec = np.round(nelec, r)

    return nelec, [nelec_per_dos, nelec_per_orbital]
|
|
|
|
|
|
|
|
def integrate_coop(coopcar, interval=None, r=None, scale=None, interactions=None, kind='individual', up=True, down=True, collapse=False):
    ''' Integrates the bonding (positive) and antibonding (negative) parts of COOP curves.

    The COOP data is read with dft.io.read_coop(coopcar, collapse=collapse). Each selected column is integrated as
    sum(values) * dE over the grid points lying strictly inside the energy interval, where
    dE = (E_max - E_min) / (number of grid points).

    Inputs:
    coopcar: passed on unchanged to dft.io.read_coop()
    interval: [lower limit, upper limit] on the "Energy" column. Defaults to everything below 0
    r: number of decimals to round the results to (0 is a valid choice). No rounding if None
    scale: a factor the bonding, antibonding and difference values are multiplied with (the percentage is not scaled)
    interactions: list of interaction numbers (1-based), or a list of such lists. Required.
    kind: 'avg', 'average' or 'mean' integrates the mean of the selected columns instead of each column
    up, down: which spin channels to integrate when collapse is disabled
    collapse: passed on to dft.io.read_coop(); when enabled only one column per interaction is selected

    Output:
    [bonding, antibonding, difference, percentage_bonding]: totals over all integrated columns
    [bonding_interactions, antibonding_interactions, difference_interactions, percentage_bonding_interactions]: the same quantities per integrated column

    Raises:
    ValueError: if interactions is not given, or if neither up nor down is selected while collapse is disabled
    '''

    coopcar, coop_interactions = dft.io.read_coop(coopcar, collapse=collapse)

    if not interactions:
        raise ValueError("integrate_coop() does not support being called without interactions")

    # Make interactions into a list of lists for correct looping below
    interactions_list = interactions if isinstance(interactions[0], list) else [interactions]

    use_mean = kind in ('avg', 'average', 'mean')
    n_interactions = len(coop_interactions)

    # NOTE(review): every group is processed, but only the columns selected for the last
    # group are integrated below. This mirrors the previous behaviour - confirm it is intended.
    to_integrate = []
    for group in interactions_list:
        up_columns = [2*(i-1)+3 for i in group]
        down_columns = [2*(i-1)+5 + 2*n_interactions for i in group]

        if collapse:
            to_integrate = up_columns

            if use_mean:
                coopcar["mean"] = coopcar.iloc[:, to_integrate].mean(axis=1)
                to_integrate = [coopcar.columns.get_loc('mean')]

        elif up and down:
            to_integrate = up_columns + down_columns

            if use_mean:
                coopcar["mean_up"] = coopcar.iloc[:, up_columns].mean(axis=1)
                coopcar["mean_down"] = coopcar.iloc[:, down_columns].mean(axis=1)
                to_integrate = [coopcar.columns.get_loc('mean_up'), coopcar.columns.get_loc('mean_down')]

        elif up:
            to_integrate = up_columns

            if use_mean:
                coopcar["mean_up"] = coopcar.iloc[:, to_integrate].mean(axis=1)
                to_integrate = [coopcar.columns.get_loc('mean_up')]

        elif down:
            to_integrate = down_columns

            if use_mean:
                coopcar["mean_down"] = coopcar.iloc[:, to_integrate].mean(axis=1)
                to_integrate = [coopcar.columns.get_loc('mean_down')]

        else:
            raise ValueError("At least one of up and down must be enabled when collapse is disabled")

    energy = coopcar["Energy"]
    dE = (energy.max() - energy.min()) / len(energy)

    # Sets interval to everything below the Fermi-level by default if not specified in function call
    if interval:
        emin, emax = interval[0], interval[1]
    else:
        emin, emax = energy.min(), 0

    energy_values = np.asarray(energy)
    inside = (energy_values > emin) & (energy_values < emax)

    bonding = 0
    antibonding = 0
    bonding_interactions = []
    antibonding_interactions = []
    difference_interactions = []
    percentage_bonding_interactions = []

    for column in to_integrate:
        values = np.asarray(coopcar.iloc[:, column])[inside]

        bonding_interaction = values[values > 0].sum() * dE
        antibonding_interaction = np.abs(values[values < 0]).sum() * dE

        # The totals are accumulated from the unscaled, unrounded values
        bonding += bonding_interaction
        antibonding += antibonding_interaction

        difference_interaction = bonding_interaction - antibonding_interaction
        percentage_bonding_interaction = _bonding_percentage(bonding_interaction, antibonding_interaction)

        if scale:
            bonding_interaction = bonding_interaction * scale
            antibonding_interaction = antibonding_interaction * scale
            difference_interaction = difference_interaction * scale

        if r is not None:
            bonding_interaction = np.round(bonding_interaction, r)
            antibonding_interaction = np.round(antibonding_interaction, r)
            difference_interaction = np.round(difference_interaction, r)
            percentage_bonding_interaction = np.round(percentage_bonding_interaction, r)

        bonding_interactions.append(bonding_interaction)
        antibonding_interactions.append(antibonding_interaction)
        difference_interactions.append(difference_interaction)
        percentage_bonding_interactions.append(percentage_bonding_interaction)

    difference = bonding - antibonding
    percentage_bonding = _bonding_percentage(bonding, antibonding)

    if scale:
        bonding = bonding * scale
        antibonding = antibonding * scale
        difference = difference * scale

    if r is not None:
        bonding = np.round(bonding, r)
        antibonding = np.round(antibonding, r)
        difference = np.round(difference, r)
        percentage_bonding = np.round(percentage_bonding, r)

    return [bonding, antibonding, difference, percentage_bonding], [bonding_interactions, antibonding_interactions, difference_interactions, percentage_bonding_interactions]


def _bonding_percentage(bonding, antibonding):
    ''' Returns the bonding share in percent, or NaN when there is nothing to divide by. '''

    total = bonding + antibonding

    if total == 0:
        return float('nan')

    return bonding / total * 100
|
|
|
|
|
|
|
|
|
|
def plot_pdos(data: dict, options=None):
    ''' Plots projected density of states curves per atom and orbital.

    Inputs:
    data: passed to dft.io.read_pdos(data=data, options=options). Plotting is only performed when the returned
          data['pdos'] is a dictionary keyed by atom, in which case data['atoms']['specie'] is used as the default list of atoms
    options: dictionary overriding any of the default_options below

    Output:
    None. The curves are drawn on the Axes given in options['ax'], or on a new one from btp.prepare_plot()
    '''

    # A fresh dictionary per call: options is written to below, and a shared mutable
    # default would leak those writes into later calls
    if options is None:
        options = {}

    default_options = {
        'xlabel': 'Energy', 'xunit': 'eV', 'xlim': None, 'x_tick_locators': None,
        'ylabel': 'Partial density of states', 'yunit': 'arb.u.', 'ylim': None, 'y_tick_locators': None,
        'mark_fermi_level': True, # Adds a dashed line to mark the Fermi-level
        'flip_axes': False, # Flips x- and y-axes
        'plot_indices': [], # List which indices to plot. If options["sum_atoms"] == True, this needs to be a list of lists, each specifying the index of a given atom
        'plot_atoms': [], # List of which atoms to plot. Only used if options["sum_atoms"] == True.
        'plot_orbitals': [], # List of which orbitals to plot. If options["sum_atoms"] == True, this needs to be a list of lists, each specifying the orbitals of a given atom
        'atom_colours': [], # Colours of each atom. Should be a colour for each atom, only in use if options["sum_atoms"] == True.
        'orbital_colours': [], # Colours of each orbital. The list should always correspond to the shape of options["plot_orbitals"].
        'fill': False,
        'fig': None, # Matplotlib Figure object
        'ax': None, # Matplotlib Axes object
    }

    options = aux.update_options(options=options, default_options=default_options)

    # Remembers whether the x/y options have already been swapped, so they are swapped at most once
    if 'axes_flipped' not in options:
        options['axes_flipped'] = False

    data = dft.io.read_pdos(data=data, options=options)

    if not options['fig'] and not options['ax']:
        fig, ax = btp.prepare_plot(options=options)
    else:
        fig, ax = options['fig'], options['ax']

    # If options['sum_atoms'] == True
    if isinstance(data['pdos'], dict):

        # Populate the plot_atoms and plot_orbitals lists if they are not passed. Defaults to showing everything
        if not options['plot_atoms']:
            options['plot_atoms'] = data['atoms']['specie']

        if not options['plot_orbitals']:
            for _ in options['plot_atoms']:
                options['plot_orbitals'].append([])

        # Fill in the orbital list of every atom that has none, in case plot_orbitals was passed with empty entries
        for i, atom in enumerate(options['plot_atoms']):
            if not options['plot_orbitals'][i]:
                options['plot_orbitals'][i] = [orbital for orbital in data['pdos'][atom].columns if 'Energy' not in orbital]

        # Populate the atom_colours and orbital_colours. Defaults to same colour for all orbitals of one specie.
        if not options['atom_colours']:
            options['palettes'] = [('qualitative', 'Dark2_8')]
            colour_cycle = generate_colours(options=options)

            for _ in options['plot_atoms']:
                options['atom_colours'].append(next(colour_cycle))

        if not options['orbital_colours']:
            for i, orbitals in enumerate(options['plot_orbitals']):
                options['orbital_colours'].append([options['atom_colours'][i] for _ in orbitals])

        for i, atom in enumerate(options['plot_atoms']):
            atom_dos = data['pdos'][atom]
            orbitals = options['plot_orbitals'][i]

            if options['flip_axes']:
                for j, orbital in enumerate(orbitals):
                    colour = options['orbital_colours'][i][j]

                    if options['fill']:
                        ax.fill_betweenx(y=atom_dos['Energy'], x1=atom_dos[orbital], x2=0, color=colour, ec=(0,0,0,1))
                    else:
                        ax.plot(atom_dos[orbital], atom_dos['Energy'], color=colour)

            else:
                # NOTE(review): options['fill'] has no effect when the axes are not flipped
                atom_dos.plot(x='Energy', y=orbitals, color=options['orbital_colours'][i], ax=ax)

        if options['flip_axes']:

            if not options['axes_flipped']:
                options = aux.swap_values(options=options,
                    key1=['xlim', 'xunit', 'xlabel', 'x_tick_locators'],
                    key2=['ylim', 'yunit', 'ylabel', 'y_tick_locators']
                )

                options['axes_flipped'] = True

            # Solid line at zero DOS, dashed line at zero energy
            ax.axvline(x=0, c='black')

            if options['mark_fermi_level']:
                ax.axhline(y=0, c='black', ls='--')

        else:
            ax.axhline(y=0, c='black')

            # The dashed line is optional in this orientation as well
            if options['mark_fermi_level']:
                ax.axvline(x=0, c='black', ls='--')

        fig, ax = btp.adjust_plot(fig=fig, ax=ax, options=options)

    # NOTE: the case where data['pdos'] is a list (one DOS per atom, selected through
    # options['plot_indices']) is not implemented

    return None
|
|
|
|
|
|
|
|
|
|
|
|
def plot_partial_dos_legacy(data: dict, options={}):
    ''' Legacy routine that plots projected DOS curves per atom and orbital onto a Matplotlib Axes.

    Inputs:
    data: dictionary; data['poscar'] is passed to dft.io.get_atoms() and the whole dictionary to dft.io.read_pdos()
    options: dictionary overriding any of the default_options below

    Output:
    [pdos, ax, fig]: the projected DOS as returned by dft.io.read_pdos(), and the Axes and Figure that were drawn on
    '''

    required_options = ['atoms', 'orbitals', 'up', 'down', 'sum_atoms', 'collapse_spin', 'sum_orbitals', 'palettes', 'colours', 'fig', 'ax']

    default_options = {
        'atoms': None,
        'orbitals': None,
        'up': True,
        'down': True,
        'sum_atoms': True,
        'collapse_spin': False,
        'sum_orbitals': False,
        'palettes': [('qualitative', 'Dark2_8')],
        'colours': None,
        'fig': None,
        'ax': None



    }

    # NOTE(review): update_options is called unqualified here, while other functions in this file use aux.update_options - verify it is in scope
    options = update_options(options=options, required_options=required_options, default_options=default_options)

    # Only create a new figure when neither a Figure nor an Axes is supplied
    if not options['ax'] and not options['fig']:
        fig, ax = btp.prepare_plot(options)
    else:
        fig, ax = options['fig'], options['ax']

    species, *_ = dft.io.get_atoms(data['poscar'])

    pdos, options['dos_info'] = dft.io.read_pdos(data=data, options=options) # Extract projected DOS from DOSCAR, decomposed on individual atoms and orbitals Should yield list of N DataFrames where N is number of atoms in POSCAR


    # Default orbital names depend on whether the orbitals are summed
    if not options['orbitals']:
        options['orbitals'] = ['s', 'p1', 'p2', 'p3', 'd1', 'd2', 'd3', 'd4', 'd5'] if not options['sum_orbitals'] else ['s', 'p', 'd']

    # NOTE(review): in this branch a colour cycle is created but options['colours'] is left untouched,
    # while options['colours'][i][j] is indexed further down - presumably unfinished, verify
    if not options['colours']:
        colour_cycle = generate_colours(options=options)
        #
        #colours = []
        #for orbital in options['orbitals']:
        #    colours.append(next(colour_cycle))
        #
        # else:
        #    colours = options['colours']

    # A single colour is repeated as a one-element list per atom
    elif not isinstance(options['colours'], list):
        new_colours = []
        for atom in options['atoms']:
            new_colours.append([options['colours']])

        options['colours'] = new_colours


    # NOTE(review): looks like leftover debug output
    print(options['colours'])


    # NOTE(review): each entry becomes a one-element list holding the whole orbital list,
    # so the orbital loop below iterates over a list rather than names - confirm the nesting is intended
    if not isinstance(options['orbitals'][0], list):
        new_orbitals = []
        for atom in options['atoms']:
            new_orbitals.append([options['orbitals']])

        options['orbitals'] = new_orbitals


    if options['atoms']:
        for i, atom in enumerate(options['atoms']):

            # With sum_atoms, atom is matched against the entries of species (last match wins);
            # otherwise atom is used as a 1-based index
            if options['sum_atoms']:
                for ind, specie in enumerate(species):
                    if specie == atom:
                        atom_index = ind
            else:
                atom_index = atom-1


            for j, orbital in enumerate(options['orbitals'][i]):

                colour = options['colours'][i][j]

                # Spin-polarised data is read from the '_u' and '_d' suffixed columns
                if options['dos_info']['spin_polarised']:
                    if options['up']:
                        pdos[atom_index].plot(x='Energy', y=orbital+'_u', ax=ax, c=colour)

                    if options['down']:
                        pdos[atom_index].plot(x='Energy', y=orbital+'_d', ax=ax, c=colour)
                else:
                    pdos[atom_index].plot(x='Energy', y=orbital, ax=ax, c=colour)




    btp.adjust_plot(fig=fig, ax=ax, options=options)

    return [pdos, ax, fig]
|
|
|
|
|
|
|
|
def get_pdos_indices(poscar, atoms):
    ''' Unfinished stub: reads the atom information of poscar through dft.io.get_atoms() and expects three values back.

    Nothing is derived from atoms yet and nothing is returned. '''

    atom_info = dft.io.get_atoms(poscar)
    species, atom_num, atoms_dict = atom_info
|
|
|
|
|
|
|
|
|
|
def get_pdos(doscar='DOSCAR', nedos=301, headerlines=6, spin=True, adjust=True, manual_adjust=None):
    ''' Reads the per-atom projected density of states blocks of a DOSCAR into DataFrames.

    Inputs:
    doscar: passed to dft.io.open_doscar(), dft.io.get_number_of_atoms() and dft.io.get_fermi_level()
    nedos: number of lines (energy grid points) in each block
    headerlines: number of lines before the first block
    spin: whether each orbital has an up and a down column. Down columns are stored with negative sign
    adjust: whether the energies should be shifted
    manual_adjust: the value subtracted from the energies when adjust is enabled. If None, the value from dft.io.get_fermi_level() is used. An explicit 0 is respected

    Output:
    pdos: a list with one DataFrame per atom, holding an "Energy" column, one column per orbital (and spin) and the summed column(s) total, or total_up and total_down
    '''

    lines = dft.io.open_doscar(doscar)
    number_of_atoms = dft.io.get_number_of_atoms(doscar)

    if not adjust:
        e_fermi = 0
    elif manual_adjust is not None:
        e_fermi = manual_adjust
    else:
        e_fermi = dft.io.get_fermi_level(doscar)

    # Column order as the values appear on each line
    total = ["s", "p_y", "p_z", "p_x", "d_xy", "d_yz", "d_z2-r2", "d_xz", "d_x2-y2"]
    up = [f"{orbital}_up" for orbital in total]
    down = [f"{orbital}_down" for orbital in total]

    if spin:
        # Up and down values alternate for each orbital
        columns = ["Energy"] + [column for pair in zip(up, down) for column in pair]
    else:
        columns = ["Energy"] + total

    pdos = []

    for i in range(1, number_of_atoms+1):

        # First line of the block of atom i: the header, the block before the first atom
        # and one extra line in front of every block are skipped
        start = headerlines + (nedos*i) + i

        atom_dos = []
        for j in range(start, start + nedos):
            values = [float(value) for value in lines[j].split()]
            values[0] = values[0] - e_fermi
            atom_dos.append(values)

        atom_df = pd.DataFrame(data=atom_dos, columns=columns)

        if spin:
            atom_df[down] = -atom_df[down]

            atom_df = atom_df.assign(total_up = atom_df[up].sum(axis=1))
            atom_df = atom_df.assign(total_down = atom_df[down].sum(axis=1))
        else:
            atom_df = atom_df.assign(total = atom_df[total].sum(axis=1))

        pdos.append(atom_df)

    return pdos
|
|
|
|
|
|
|
|
def prepare_plot(options={}):
    ''' Creates a Matplotlib Figure and Axes sized according to the formatting parameters.

    Inputs:
    options: dictionary that must contain 'rc_params' (passed to update_rc_params()) and 'format_params' (completed with the defaults below)

    Output:
    fig, ax: the objects returned by plt.subplots()
    '''

    rc_settings = options['rc_params']
    format_settings = options['format_params']

    default_options = {
        'single_column_width': 8.3,
        'double_column_width': 17.1,
        'column_type': 'single',
        'width_ratio': '1:1',
        'aspect_ratio': '1:1',
        'compress_width': 1,
        'compress_height': 1,
        'upscaling_factor': 1.0,
        'dpi': 600,
    }

    # Every default is a required option, listed in the same order
    required_options = list(default_options)

    options = update_options(format_settings, required_options, default_options)

    # Start from Matplotlib's defaults before applying any passed run commands
    plt.rcdefaults()
    update_rc_params(rc_settings)

    base_width = determine_width(options)
    base_height = determine_height(options, base_width)
    figure_width, figure_height = scale_figure(options=options, width=base_width, height=base_height)

    fig, ax = plt.subplots(figsize=(figure_width, figure_height), dpi=options['dpi'])

    return fig, ax
|
|
|
|
|
|
def prepare_dos_plot(width=None, height=None, square=True, dpi=None, colour_cycle=('qualitative', 'Dark2_8'), energyunit='eV', dosunit='arb. u.', scale=1, pdos=None):
    ''' Sets line and axes widths through plt.rc() and creates a single-panel Figure and Axes with white background.

    Inputs:
    width, height: figure size. With square enabled, width falls back to 20 and height falls back to width
    square: whether the fallbacks above are applied
    dpi: passed to plt.subplots()
    scale: factor for the axes width, and for the line width unless pdos is given
    pdos: if truthy, the line width is fixed to 3
    colour_cycle, energyunit, dosunit: accepted but not used here

    Output:
    fig, ax: the objects returned by plt.subplots()
    '''

    line_thickness = 3 if pdos else 3*scale
    axes_thickness = 3*scale

    plt.rc('lines', linewidth=line_thickness)
    plt.rc('axes', linewidth=axes_thickness)

    if square:
        width = width or 20
        height = height or width

    figure, axes = plt.subplots(nrows=1, ncols=1, figsize=(width, height), facecolor='w', dpi=dpi)

    return figure, axes
|
|
|
|
def prettify_dos_plot(fig, ax, options):
    ''' Applies labels, ticks, title, legend, limits, custom text and zero lines to a DOS/PDOS/COOP/COHP plot.

    Inputs:
    fig, ax: the Matplotlib Figure and Axes to adjust
    options: dictionary overriding any of the default_options below. Two further keys are read if present:
             'e_fermi' (truthy draws a dashed line at zero energy) and 'dos_info' (its 'spin_polarised' entry
             decides whether a solid zero line is drawn for DOS and PDOS plots)

    Output:
    fig, ax: the same objects that were passed in
    '''

    required_options = ['plot_kind', 'flip_xy', 'hide_x_labels', 'hide_y_labels', 'xlabel', 'ylabel', 'xunit', 'yunit', 'xlim', 'ylim', 'x_tick_locators', 'y_tick_locators', 'y_tick_format', 'x_tick_format', 'hide_x_ticks', 'hide_y_ticks', 'hide_x_ticklabels', 'hide_y_ticklabels',
                        'colours', 'palettes', 'title', 'legend', 'labels', 'label_colours', 'legend_position', 'legend_ncol', 'subplots_adjust', 'text']

    default_options = {
        'plot_kind': 'PDOS', # DOS/PDOS/COOP/COHP
        'flip_xy': False,
        'hide_x_labels': False, # Whether x labels should be hidden
        'hide_x_ticklabels': False,
        'hide_x_ticks': False,
        'hide_y_labels': False, # whether y labels should be hidden
        'hide_y_ticklabels': False,
        'hide_y_ticks': False,
        'xlabel': 'Energy',
        'ylabel': 'DOS',
        'xunit': r'eV', # The unit of the x-values in the curve plot
        'yunit': r'a.u.', # The unit of the y-values in the curve and bar plots
        'xlim': None,
        'ylim': None,
        'x_tick_locators': [1, .5], # Major and minor tick locators
        'y_tick_locators': [1, .5],
        'y_tick_format': None,
        'x_tick_format': None,
        'colours': None,
        'palettes': [('qualitative', 'Dark2_8'), ('qualitative', 'Paired_12')],
        'title': None,
        'legend': True,
        'labels': None,
        'label_colours': None,
        'legend_position': ['upper center', (0.20, 0.90)], # the position of the legend passed as arguments to loc and bbox_to_anchor respectively
        'legend_ncol': 1,
        'subplots_adjust': [0.1, 0.1, 0.9, 0.9],
        'text': None
    }

    # An explicitly passed plot kind doubles as the y-label unless a label is passed as well
    if 'ylabel' not in options and options.get('plot_kind') in ('DOS', 'PDOS', 'COOP', 'COHP'):
        options['ylabel'] = options['plot_kind']

    options = update_options(options=options, required_options=required_options, default_options=default_options)

    if options['flip_xy']:
        # Switch all the x- and y-specific values
        axis_pairs = (
            ('xlim', 'ylim'),
            ('xunit', 'yunit'),
            ('xlabel', 'ylabel'),
            ('x_tick_locators', 'y_tick_locators'),
            ('hide_x_labels', 'hide_y_labels'),
        )

        for x_key, y_key in axis_pairs:
            options = aux.swap_values(dict=options, key1=x_key, key2=y_key)

    # Set labels on x- and y-axes
    if not options['hide_y_labels']:
        ax.set_ylabel(f'{options["ylabel"]} [{options["yunit"]}]')
    else:
        ax.set_ylabel('')

    if not options['hide_x_labels']:
        ax.set_xlabel(f'{options["xlabel"]} [{options["xunit"]}]')
    else:
        ax.set_xlabel('')

    # Hide x- and y- ticklabels
    if options['hide_y_ticklabels']:
        ax.tick_params(axis='y', direction='in', which='both', labelleft=False, labelright=False)
    if options['hide_x_ticklabels']:
        ax.tick_params(axis='x', direction='in', which='both', labelbottom=False, labeltop=False)

    # Hide x- and y-ticks:
    if options['hide_y_ticks']:
        ax.tick_params(axis='y', direction='in', which='both', left=False, right=False)
    if options['hide_x_ticks']:
        ax.tick_params(axis='x', direction='in', which='both', bottom=False, top=False)

    # Set multiple locators
    ax.yaxis.set_major_locator(MultipleLocator(options['y_tick_locators'][0]))
    ax.yaxis.set_minor_locator(MultipleLocator(options['y_tick_locators'][1]))

    ax.xaxis.set_major_locator(MultipleLocator(options['x_tick_locators'][0]))
    ax.xaxis.set_minor_locator(MultipleLocator(options['x_tick_locators'][1]))

    # Set title
    if options['title']:
        ax.set_title(options['title'])

    # Change format of axis tick labels if specified
    if options['y_tick_format']:
        ax.yaxis.set_major_formatter(FormatStrFormatter(options['y_tick_format']))
    if options['x_tick_format']:
        ax.xaxis.set_major_formatter(FormatStrFormatter(options['x_tick_format']))

    # Remove any legend already present on the Axes before a new one is built
    if ax.get_legend():
        ax.get_legend().remove()

    if options['legend'] and options['labels']:

        # Generate colours: one per label, unless colours are passed
        if not options['colours']:
            colour_cycle = generate_colours(palettes=options['palettes'])
            colours = [next(colour_cycle) for _ in options['labels']]
        else:
            colours = options['colours']

        # Dedicated label colours take precedence over both
        if options['label_colours']:
            colours = options['label_colours']

        # Create legend
        patches = [mpatches.Patch(color=colours[i], label=label) for i, label in enumerate(options['labels'])]

        ax.legend(handles=patches, loc=options['legend_position'][0], bbox_to_anchor=options['legend_position'][1], frameon=False, ncol=options['legend_ncol'])

    # Adjust where the axes start within the figure. Default value is 10% in from the left and bottom edges. Used to make room for the plot within the figure size (to avoid using bbox_inches='tight' in the savefig-command, as this screws with plot dimensions)
    plt.subplots_adjust(left=options['subplots_adjust'][0], bottom=options['subplots_adjust'][1], right=options['subplots_adjust'][2], top=options['subplots_adjust'][3])

    # If limits for x- and y-axes is passed, sets these.
    if options['xlim']:
        ax.set_xlim(options['xlim'])

    if options['ylim']:
        ax.set_ylim(options['ylim'])

    # Add custom text, passed as [string, (x, y)]
    if options['text']:
        plt.text(x=options['text'][1][0], y=options['text'][1][1], s=options['text'][0])

    # 'e_fermi' and 'dos_info' have no defaults, so they are looked up without failing when absent
    if options.get('e_fermi'):
        if options['flip_xy']:
            ax.axhline(0, c='black', ls='dashed')
        else:
            ax.axvline(0, c='black', ls='dashed')

    if options['plot_kind'] in ('DOS', 'PDOS'):
        draw_zero_line = bool((options.get('dos_info') or {}).get('spin_polarised'))
    else:
        draw_zero_line = options['plot_kind'] in ('COOP', 'COHP')

    if draw_zero_line:
        if options['flip_xy']:
            ax.axvline(0, c='black')
        else:
            ax.axhline(0, c='black')

    return fig, ax
|
|
|
|
|
|
|
|
def plot_coop(data, options):
    ''' Plots COOP curves for the selected interactions.

    Inputs:
    data: passed to dft.io.read_coop(data=data, options=options)
    options: dictionary overriding any of the default_options below. options['interactions'] is a list with the
             number of each interaction (index + 1 of the interactions list from read_coop), or a list of such lists.
             Nothing is plotted if it is not given.

    Output:
    coopcar, fig, ax: the COOP DataFrame (including any sum or mean columns added here), the Figure and the Axes

    Raises:
    ValueError: if neither 'up' nor 'down' is enabled while 'collapse' is disabled
    '''

    default_options = {
        'plot_kind': 'COOP',
        'mode': 'individual',
        'fill': False,
        'up': True,
        'down': True,
        'collapse': False,
        'interactions': None,
        'palettes': [('qualitative', 'Dark2_8')],
        'colours': None,
        'flip_xy': False
    }

    options = aux.update_options(options=options, default_options=default_options)

    fig, ax = btp.prepare_plot(options=options)

    coopcar, coop_interactions = dft.io.read_coop(data=data, options=options)

    # One colour per interaction unless colours are passed
    if not options['colours']:
        colour_cycle = btp.generate_colours(palettes=options['palettes'])
        colours = [next(colour_cycle) for _ in range(len(coop_interactions))]
    else:
        colours = options['colours']

    def pick_colour(index):
        # Wraps around so that more curves than colours does not raise an IndexError
        return colours[index % len(colours)]

    mode = options['mode']
    if mode in ('avg', 'average', 'mean'):
        aggregate = 'mean'
    elif mode == 'sum':
        aggregate = 'sum'
    else:
        aggregate = None

    # If interactions has been specified
    if options['interactions']:

        # Make interactions into a list of lists for correct looping below
        if isinstance(options['interactions'][0], list):
            interactions_list = options['interactions']
        else:
            interactions_list = [options['interactions']]

        for ind, interactions in enumerate(interactions_list):

            # Column positions of the selected interactions
            first_columns = [2*(i-1)+3 for i in interactions]
            down_columns = [2*(i-1)+5 + 2*len(coop_interactions) for i in interactions]

            # Pairs of (suffix of the aggregated column name, column positions)
            if options['collapse']:
                selections = [('', first_columns)]
            elif options['up'] and options['down']:
                selections = [('_up', first_columns), ('_down', down_columns)]
            elif options['up']:
                selections = [('_up', first_columns)]
            elif options['down']:
                selections = [('_down', down_columns)]
            else:
                raise ValueError("At least one of 'up' and 'down' must be enabled when 'collapse' is disabled")

            to_plot = []
            for suffix, positions in selections:
                if aggregate == 'sum':
                    coopcar['sum' + suffix] = coopcar.iloc[:, positions].sum(axis=1)
                    to_plot.append('sum' + suffix)
                elif aggregate == 'mean':
                    coopcar['mean' + suffix] = coopcar.iloc[:, positions].mean(axis=1)
                    to_plot.append('mean' + suffix)
                else:
                    # Positions are translated to labels so the columns can also be looked up with coopcar[...]
                    to_plot.extend(coopcar.columns[position] for position in positions)

            # Plot all columns as decided above
            for j, column in enumerate(to_plot):
                if options['fill']:
                    # Positive and negative parts get the colour of this group and of the next one
                    for where, colour in ((coopcar[column] > 0, pick_colour(ind)), (coopcar[column] < 0, pick_colour(ind + 1))):
                        if options['flip_xy']:
                            ax.fill_betweenx(coopcar["Energy"], coopcar[column], 0, where=where, color=colour)
                        else:
                            ax.fill_between(coopcar["Energy"], coopcar[column], 0, where=where, color=colour)

                else:
                    colour = pick_colour(j) if mode == "individual" else pick_colour(ind)

                    if options['flip_xy']:
                        coopcar.plot(y='Energy', x=column, ax=ax, color=colour)
                    else:
                        coopcar.plot(x='Energy', y=column, ax=ax, color=colour)

    fig, ax = btp.adjust_plot(fig=fig, ax=ax, options=options)

    return coopcar, fig, ax
|
|
|
|
|
|
|
|
def prettify_coop_plot(fig, ax, energyunit='eV', dosunit='arb. u.', xlim=None, ylim=None, title=None, hide_ylabels=False, hide_xlabels=False, hide_yvals=False, hide_xvals=False, flip_xy=False, pad_bottom=None, scale=1, colours=None, atoms=None, pdos=False, width=None, height=None, e_fermi=False, adjust=False, legend=False, labels=None, label_colours=None, xpad=0, ypad=0):
    ''' Applies axis labels, tick styling, axis limits, title, legend and guide lines to a COOP plot.

    Input:
    fig, ax: matplotlib figure and axes to style. Tick font sizes, axis limits and the guide lines are set through pyplot, i.e. on pyplot's current axes.
    energyunit, dosunit: unit strings inserted into the energy- and COOP-axis labels
    xlim, ylim: axis limits passed to plt.xlim() / plt.ylim(). When flip_xy is False, ylim also sets the y tick spacing (a quarter of the span)
    title: optional title, drawn with a fixed font size of 40
    hide_ylabels, hide_xlabels, hide_yvals, hide_xvals: hide axis labels / tick values. With flip_xy=True only hide_ylabels and hide_yvals are consulted; the x tick values are always hidden in that orientation
    flip_xy: if True, energy is on the y-axis and COOP on the x-axis
    pad_bottom: if not None, an extra transparent axes is added to the figure with this tick padding
    scale: multiplier applied to all font sizes, line widths and tick lengths
    colours: colours of the legend patches (overridden by label_colours if that is given)
    atoms, pdos, width: accepted for backward compatibility, not used
    height: if 10 or larger (and flip_xy is False), the long form of the COOP axis label is used. None selects the short form
    e_fermi: if True, draws a dashed line at zero energy
    adjust: if True (and flip_xy is False), the energy axis is labelled as E - E_F
    legend, labels: if legend is True, a legend with one coloured patch per label is added to the figure
    xpad, ypad: if non-zero, used both as label padding and as tick padding of the respective axis

    Output:
    fig, ax: the figure and axes objects that were passed in'''

    # Set sizes of ticks, labels etc.
    ticksize = 30*scale
    labelsize = 30*scale
    legendsize = 30*scale

    linewidth = 3*scale
    majorticklength = 20*scale
    minorticklength = 10*scale

    # labelpad is only forwarded when a non-zero padding is requested, so that matplotlib's default applies otherwise
    xlabel_kwargs = {'size': labelsize}
    if xpad:
        xlabel_kwargs['labelpad'] = xpad

    ylabel_kwargs = {'size': labelsize}
    if ypad:
        ylabel_kwargs['labelpad'] = ypad

    plt.xticks(fontsize=ticksize)
    plt.yticks(fontsize=ticksize)

    if flip_xy:

        # Energy on the y-axis, COOP on the x-axis
        if not hide_ylabels:
            ax.set_ylabel('Energy [{}]'.format(energyunit), **ylabel_kwargs)

        # The label is the same regardless of pdos and width, so neither is inspected (comparing width=None with a number would raise)
        ax.set_xlabel('COOP [{}]'.format(dosunit), **xlabel_kwargs)

        ax.tick_params(axis='y', direction='in', which='major', right=True, length=majorticklength, width=linewidth)
        ax.tick_params(axis='y', direction='in', which='minor', right=True, length=minorticklength, width=linewidth)

        if hide_yvals:
            ax.tick_params(axis='y', labelleft=False)

        ax.tick_params(axis='x', direction='in', which='major', bottom=False, labelbottom=False)

        ax.yaxis.set_major_locator(MultipleLocator(1))
        ax.yaxis.set_minor_locator(MultipleLocator(.5))

    else:

        # Energy on the x-axis, COOP on the y-axis
        if adjust:
            ax.set_xlabel('E - E$_F$ [{}]'.format(energyunit), **xlabel_kwargs)
        else:
            ax.set_xlabel('Energy [{}]'.format(energyunit), **xlabel_kwargs)

        # Short label unless the figure is known to be tall enough for the long one
        if height is None or height < 10:
            ax.set_ylabel('COOP [{}]'.format(dosunit), **ylabel_kwargs)
        else:
            ax.set_ylabel('Crystal orbital overlap population [{}]'.format(dosunit), **ylabel_kwargs)

        ax.tick_params(axis='x', direction='in', which='major', bottom=True, top=True, length=majorticklength, width=linewidth)
        ax.tick_params(axis='x', direction='in', which='minor', bottom=True, top=True, length=minorticklength, width=linewidth)

        ax.tick_params(axis='y', which='major', direction='in', right=True, left=True, labelleft=True, length=majorticklength, width=linewidth)
        ax.tick_params(axis='y', which='minor', direction='in', right=True, left=True, length=minorticklength, width=linewidth)

        if hide_ylabels:
            ax.set_ylabel('')
        if hide_xlabels:
            ax.set_xlabel('')
        if hide_yvals:
            ax.tick_params(axis='y', which='both', labelleft=False)
        if hide_xvals:
            ax.tick_params(axis='x', which='both', labelbottom=False)

        if ylim:
            yspan = ylim[1] - ylim[0]
            yloc = np.round(yspan / 4, 2)

            # A very narrow span rounds to a zero step, which is not a usable tick spacing; keep the default locators then
            if yloc > 0:
                ax.yaxis.set_major_locator(MultipleLocator(yloc))
                ax.yaxis.set_minor_locator(MultipleLocator(yloc/2))

        ax.xaxis.set_major_locator(MultipleLocator(1))
        ax.xaxis.set_minor_locator(MultipleLocator(.5))

    plt.xlim(xlim)
    plt.ylim(ylim)

    if title:
        ax.set_title(title, size=40)

    if legend:

        if label_colours:
            colours = label_colours

        if labels is None or colours is None:
            raise ValueError("legend=True requires labels and either colours or label_colours")

        patches = [mpatches.Patch(color=colours[ind], label=label) for ind, label in enumerate(labels)]

        fig.legend(handles=patches, loc='upper right', ncol=len(labels), bbox_to_anchor=(0.8, 0.45), fontsize=legendsize/1.25, frameon=False)

    if pad_bottom is not None:
        # Transparent axes with only a bottom spine; its white ticks/labels use pad_bottom as padding
        bigax = fig.add_subplot(111)
        bigax.set_facecolor([1,1,1,0])
        bigax.spines['top'].set_visible(False)
        bigax.spines['bottom'].set_visible(True)
        bigax.spines['left'].set_visible(False)
        bigax.spines['right'].set_visible(False)
        bigax.tick_params(labelcolor='w', color='w', direction='in', top=False, bottom=True, left=False, right=False, labelleft=False, pad=pad_bottom)

    if xpad:
        ax.tick_params(axis='x', pad=xpad)

    if ypad:
        ax.tick_params(axis='y', pad=ypad)

    if e_fermi:
        # Dashed line at zero energy, along whichever axis holds the energy
        if flip_xy:
            plt.axhline(0, lw=linewidth, c='black', ls='dashed')
        else:
            plt.axvline(0, lw=linewidth, c='black', ls='dashed')

    # Solid horizontal line at y = 0
    plt.axhline(0, lw=linewidth, c='black')

    return fig, ax
|
|
|
|
|
|
|
|
|
|
def get_unique_atoms(interactions):
    ''' Get all the unique atoms involved in the interactions from the COOP-calculation

    Input:
    interactions: list of interactions that comes as output from read_coop(). Each interaction is a string where the atom labels are separated by '->'

    Output:
    unique_atoms: sorted list of the unique atom labels in the interactions list'''

    # A set gives constant-time de-duplication; sorted() returns the labels as an ordered list
    unique_atoms = {atom for interaction in interactions for atom in interaction.split('->')}

    return sorted(unique_atoms)
|
|
|
|
|
|
def get_interactions_involving(interactions, targets):
    ''' Get the indices (+1) of all the interactions involving the target atom(s). This list can be used as input to plot_coop(), as it is
    then formatted the way that function accepts these interactions.

    Input:
    interactions: list of interactions as output from read_coop(). Each interaction is a string where the atom labels are separated by '->'
    targets: the particular atom that should be involved in the interactions, or a list/tuple of such atoms

    Output:
    target_interactions: Indices (+1) of all the interactions involving any of the target atoms. Each index appears once. With several
    targets, the indices are grouped in the order the targets are given.'''

    # A single atom label is handled as a one-element collection
    if not isinstance(targets, (list, tuple)):
        targets = [targets]

    target_interactions = []

    # De-duplicate on the index rather than on the label, so that two interactions sharing a label are both kept
    seen = set()

    for target in targets:
        for ind, interaction in enumerate(interactions, start=1):
            if ind not in seen and target in interaction.split('->'):
                seen.add(ind)
                target_interactions.append(ind)

    return target_interactions
|
|
|
|
|
|
|
|
|
|
def update_rc_params(rc_params):
    ''' Update matplotlib's runtime configuration (plt.rcParams) with all passed parameters.

    Input:
    rc_params: dictionary of rcParams keys and the values they should be set to. Nothing is done if it is None or empty.'''

    if rc_params:
        plt.rcParams.update(rc_params)
|
|
|
|
|
|
|
|
def update_options(options, required_options, default_options):
    ''' Fill in the default value of every required option that has not been passed.

    Input:
    options: dictionary of the options that were passed. It is updated in place. None is treated as an empty dictionary.
    required_options: list of the option names that must be present after the update
    default_options: dictionary holding the default value of each required option

    Output:
    options: the updated options dictionary (the same object as was passed in, unless None was passed)'''

    if options is None:
        options = {}

    for option in required_options:
        # The default is only looked up when it is needed, so options that are already set need no default
        if option not in options:
            options[option] = default_options[option]

    return options
|
|
|
|
|
|
|
|
def determine_width(options):
    ''' Determine the figure width in inches from the column type and the width ratio.

    Input:
    options: dictionary with the keys
        'column_type': either 'single' or 'double'
        'single_column_width' / 'double_column_width': width of the chosen column type in cm
        'width_ratio': string of the form 'a:b'. The column width is multiplied by a/b

    Output:
    width: the figure width in inches

    Raises:
    ValueError if 'column_type' is neither 'single' nor 'double' '''

    conversion_cm_inch = 0.3937008 # cm to inch

    column_type = options['column_type']

    if column_type == 'single':
        column_width = options['single_column_width']
    elif column_type == 'double':
        column_width = options['double_column_width']
    else:
        # Without this branch column_width would be unbound and the failure would not say what is wrong
        raise ValueError("options['column_type'] must be 'single' or 'double', got {!r}".format(column_type))

    column_width *= conversion_cm_inch

    width_ratio = [float(num) for num in options['width_ratio'].split(':')]

    width = column_width * width_ratio[0]/width_ratio[1]

    return width
|
|
|
|
|
|
def determine_height(options, width):
    ''' Determine the figure height that gives the requested aspect ratio for a given width.

    Input:
    options: dictionary where 'aspect_ratio' is a string of the form 'w:h'
    width: the figure width

    Output:
    height: the width divided by (w / h)'''

    parts = list(map(float, options['aspect_ratio'].split(':')))

    # Width-to-height ratio as a single number
    ratio = parts[0] / parts[1]

    return width/ratio
|
|
|
|
|
|
def scale_figure(options, width, height):
    ''' Scale the figure dimensions by the upscaling factor and the per-dimension compression factors.

    Input:
    options: dictionary with the keys 'upscaling_factor', 'compress_width' and 'compress_height'
    width, height: the unscaled figure dimensions

    Output:
    width, height: the scaled figure dimensions'''

    upscaling = options['upscaling_factor']

    scaled_width = width * upscaling * options['compress_width']
    scaled_height = height * upscaling * options['compress_height']

    return scaled_width, scaled_height
|
|
|
|
|
|
|
|
def generate_colours(options: dict):
    ''' Build an endless colour cycle from one or more palettable colorbrewer palettes.

    Input:
    options: dictionary where 'palettes' is either a single palette or a list of palettes. A palette is a pair (type, name), where
    type is the name of a submodule of palettable.colorbrewer and name is the name of a palette in that submodule. If a single
    palette is passed, options['palettes'] is replaced by a one-element list containing it.

    Output:
    colour_cycle: an itertools.cycle over the colours (mpl_colors) of all the palettes, in the order the palettes are given'''

    palettes = options['palettes']

    # A lone palette written as a list of two strings would otherwise be read as two palettes
    is_single_pair = isinstance(palettes, list) and len(palettes) == 2 and all(isinstance(part, str) for part in palettes)

    if is_single_pair or not isinstance(palettes, list):
        palettes = [palettes]
        options['palettes'] = palettes

    # Collect the colours of all the palettes into one flat list, which is then made into a cyclic iterable
    colour_collection = []

    for palette in palettes:
        mod = importlib.import_module("palettable.colorbrewer.%s" % palette[0])
        colour_collection.extend(getattr(mod, palette[1]).mpl_colors)

    return itertools.cycle(colour_collection)