Rewrite plot_pdos

This commit is contained in:
rasmusthog 2022-10-09 18:38:00 +02:00
parent 3f1d1e4d1f
commit f72bd4e77f

View file

@ -242,7 +242,136 @@ def integrate_coop(coopcar, interval=None, r=None, scale=None, interactions=None
def plot_partial_dos(data: dict, options={}): def plot_pdos(data: dict, options={}):
default_options = {
'xlabel': 'Energy', 'xunit': 'eV', 'xlim': None, 'x_tick_locators': None,
'ylabel': 'Partial density of states', 'yunit': 'arb.u.', 'ylim': None, 'y_tick_locators': None,
'mark_fermi_level': True, # Adds a dashed line to mark the Fermi-level
'flip_axes': False, # Flips x- and y-axes
'plot_indices': [], # List which indices to plot. If options["sum_atoms"] == True, this needs to be a list of lists, each specifying the index of a given atom
'plot_atoms': [], # List of which atoms to plot. Only used if options["sum_atoms"] == True.
'plot_orbitals': [], # List of which orbitals to plot. If options["sum_atoms"] == True, this needs to be a list of lists, each specifying the orbitals of a given atom
'atom_colours': [], # Colours of each atom. Should be a colour for each atom, only in use if options["sum_atoms"] == True.
'orbital_colours': [], # Colours of each orbital. The list should always correspond to the shape of options["plot_orbitals"].
'fill': False,
'fig': None, # Matplotlib Figure object
'ax': None, # Matplotlib Axes object
}
options = aux.update_options(options=options, default_options=default_options)
if 'axes_flipped' not in options.keys():
options['axes_flipped'] = False
data = dft.io.read_pdos(data=data, options=options)
if not options['fig'] and not options['ax']:
fig, ax = btp.prepare_plot(options=options)
else:
fig, ax = options['fig'], options['ax']
# If options['sum_atoms'] == True
if isinstance(data['pdos'], dict):
# Populate the plot_atoms and plot_orbitals lists if they are not passed. Defaults to showing everything
if not options['plot_atoms']:
options['plot_atoms'] = data['atoms']['specie']
if not options['plot_orbitals']:
for atom in options['plot_atoms']:
options['plot_orbitals'].append([])
# This is to fill in each orbital list for each atom. This is in case options['plot_orbitals'] is passes, but one or more of the atoms lack colours
for i, atom in enumerate(options['plot_atoms']):
if not options['plot_orbitals'] or not options['plot_orbitals'][i]:
options['plot_orbitals'][i] = [orbital for orbital in data['pdos'][atom].columns if 'Energy' not in orbital]
# Populate the atom_colours and orbital_colours. Defaults to same colour for all orbitals of one specie.
if not options['atom_colours']:
options['palettes'] = [('qualitative', 'Dark2_8')]
colour_cycle = generate_colours(options=options)
for atom in options['plot_atoms']:
options['atom_colours'].append(next(colour_cycle))
if not options['orbital_colours']:
for i, atom in enumerate(options['plot_orbitals']):
options['orbital_colours'].append([]) # Make list for specific atom
for orbital in atom:
options['orbital_colours'][i].append(options['atom_colours'][i])
for i, atom in enumerate(options['plot_atoms']):
if not options['plot_orbitals'] or not options['plot_orbitals'][i]:
options['plot_orbitals'][i] = [orbital for orbital in data['pdos'][atom].columns if 'Energy' not in orbital]
x = 'Energy'
y = options['plot_orbitals'][i]
if options['flip_axes']:
for j, orbital in enumerate(options['plot_orbitals'][i]):
if options['fill']:
ax.fill_betweenx(y=data['pdos'][atom]['Energy'], x1=data['pdos'][atom][orbital], x2=0, color=options['orbital_colours'][i][j], alpha=0.5, ec=(0,0,0,0))
else:
ax.plot(data['pdos'][atom][orbital], data['pdos'][atom]['Energy'], color=options['orbital_colours'][i][j])
else:
data['pdos'][atom].plot(x=x, y=y, color=options['orbital_colours'][i], ax=ax)
#print(options['plot_orbitals'], options['orbital_colours'])
if options['flip_axes']:
if not options['axes_flipped']:
options = aux.swap_values(options=options,
key1=['xlim', 'xunit', 'xlabel', 'x_tick_locators'],
key2=['ylim', 'yunit', 'ylabel', 'y_tick_locators']
)
options['axes_flipped'] = True #
ax.axvline(x=0, c='black')
if options['mark_fermi_level']:
ax.axhline(y=0, c='black', ls='--')
else:
ax.axhline(y=0, c='black')
ax.axvline(x=0, c='black', ls='--')
fig, ax = btp.adjust_plot(fig=fig, ax=ax, options=options)
#elif isinstance(data['pdos'], list):
# if not options['plot_atoms']:
# options['plot_atoms'] = data['atoms']['specie']
#
# if not options['plot_indices']:
# for plot_specie in options['plot_atoms']:
# for i, doscar_specie in enumerate(data['atoms']['specie']):
# if plot_specie == doscar_specie:
# options['plot_indices'].append([k for k in range(data['atoms']['number'][i])])
return None
def plot_partial_dos_legacy(data: dict, options={}):
required_options = ['atoms', 'orbitals', 'up', 'down', 'sum_atoms', 'collapse_spin', 'sum_orbitals', 'palettes', 'colours', 'fig', 'ax'] required_options = ['atoms', 'orbitals', 'up', 'down', 'sum_atoms', 'collapse_spin', 'sum_orbitals', 'palettes', 'colours', 'fig', 'ax']
@ -1080,16 +1209,6 @@ def scale_figure(options, width, height):
def swap_values(dict, key1, key2):
key1_val = dict[key1]
dict[key1] = dict[key2]
dict[key2] = key1_val
return dict
def generate_colours(options: dict): def generate_colours(options: dict):
if not isinstance(options['palettes'], list): if not isinstance(options['palettes'], list):