diff --git a/nafuma/dft/electrons.py b/nafuma/dft/electrons.py index d96d2ad..8ea1064 100644 --- a/nafuma/dft/electrons.py +++ b/nafuma/dft/electrons.py @@ -242,7 +242,136 @@ def integrate_coop(coopcar, interval=None, r=None, scale=None, interactions=None -def plot_partial_dos(data: dict, options={}): +def plot_pdos(data: dict, options={}): + + default_options = { + 'xlabel': 'Energy', 'xunit': 'eV', 'xlim': None, 'x_tick_locators': None, + 'ylabel': 'Partial density of states', 'yunit': 'arb.u.', 'ylim': None, 'y_tick_locators': None, + 'mark_fermi_level': True, # Adds a dashed line to mark the Fermi-level + 'flip_axes': False, # Flips x- and y-axes + 'plot_indices': [], # List which indices to plot. If options["sum_atoms"] == True, this needs to be a list of lists, each specifying the index of a given atom + 'plot_atoms': [], # List of which atoms to plot. Only used if options["sum_atoms"] == True. + 'plot_orbitals': [], # List of which orbitals to plot. If options["sum_atoms"] == True, this needs to be a list of lists, each specifying the orbitals of a given atom + 'atom_colours': [], # Colours of each atom. Should be a colour for each atom, only in use if options["sum_atoms"] == True. + 'orbital_colours': [], # Colours of each orbital. The list should always correspond to the shape of options["plot_orbitals"]. + 'fill': False, + 'fig': None, # Matplotlib Figure object + 'ax': None, # Matplotlib Axes object + } + + options = aux.update_options(options=options, default_options=default_options) + + if 'axes_flipped' not in options.keys(): + options['axes_flipped'] = False + + data = dft.io.read_pdos(data=data, options=options) + + + + if not options['fig'] and not options['ax']: + fig, ax = btp.prepare_plot(options=options) + else: + fig, ax = options['fig'], options['ax'] + + + # If options['sum_atoms'] == True + if isinstance(data['pdos'], dict): + + # Populate the plot_atoms and plot_orbitals lists if they are not passed. Defaults to showing everything + if not options['plot_atoms']: + options['plot_atoms'] = data['atoms']['specie'] + + if not options['plot_orbitals']: + for atom in options['plot_atoms']: + options['plot_orbitals'].append([]) + + # This is to fill in each orbital list for each atom. This is in case options['plot_orbitals'] is passes, but one or more of the atoms lack colours + for i, atom in enumerate(options['plot_atoms']): + if not options['plot_orbitals'] or not options['plot_orbitals'][i]: + options['plot_orbitals'][i] = [orbital for orbital in data['pdos'][atom].columns if 'Energy' not in orbital] + + # Populate the atom_colours and orbital_colours. Defaults to same colour for all orbitals of one specie. + if not options['atom_colours']: + options['palettes'] = [('qualitative', 'Dark2_8')] + colour_cycle = generate_colours(options=options) + + for atom in options['plot_atoms']: + options['atom_colours'].append(next(colour_cycle)) + + if not options['orbital_colours']: + for i, atom in enumerate(options['plot_orbitals']): + options['orbital_colours'].append([]) # Make list for specific atom + for orbital in atom: + options['orbital_colours'][i].append(options['atom_colours'][i]) + + + for i, atom in enumerate(options['plot_atoms']): + + if not options['plot_orbitals'] or not options['plot_orbitals'][i]: + options['plot_orbitals'][i] = [orbital for orbital in data['pdos'][atom].columns if 'Energy' not in orbital] + + x = 'Energy' + y = options['plot_orbitals'][i] + + if options['flip_axes']: + for j, orbital in enumerate(options['plot_orbitals'][i]): + + if options['fill']: + ax.fill_betweenx(y=data['pdos'][atom]['Energy'], x1=data['pdos'][atom][orbital], x2=0, color=options['orbital_colours'][i][j], alpha=0.5, ec=(0,0,0,0)) + else: + ax.plot(data['pdos'][atom][orbital], data['pdos'][atom]['Energy'], color=options['orbital_colours'][i][j]) + + + + else: + data['pdos'][atom].plot(x=x, y=y, color=options['orbital_colours'][i], ax=ax) + + #print(options['plot_orbitals'], options['orbital_colours']) + + + if options['flip_axes']: + + if not options['axes_flipped']: + options = aux.swap_values(options=options, + key1=['xlim', 'xunit', 'xlabel', 'x_tick_locators'], + key2=['ylim', 'yunit', 'ylabel', 'y_tick_locators'] + ) + + options['axes_flipped'] = True # + + ax.axvline(x=0, c='black') + + if options['mark_fermi_level']: + ax.axhline(y=0, c='black', ls='--') + + + else: + ax.axhline(y=0, c='black') + ax.axvline(x=0, c='black', ls='--') + + fig, ax = btp.adjust_plot(fig=fig, ax=ax, options=options) + + + + + #elif isinstance(data['pdos'], list): + # if not options['plot_atoms']: + # options['plot_atoms'] = data['atoms']['specie'] + # + # if not options['plot_indices']: + # for plot_specie in options['plot_atoms']: + # for i, doscar_specie in enumerate(data['atoms']['specie']): + # if plot_specie == doscar_specie: + # options['plot_indices'].append([k for k in range(data['atoms']['number'][i])]) + + + return None + + + + + +def plot_partial_dos_legacy(data: dict, options={}): required_options = ['atoms', 'orbitals', 'up', 'down', 'sum_atoms', 'collapse_spin', 'sum_orbitals', 'palettes', 'colours', 'fig', 'ax'] @@ -1080,16 +1209,6 @@ def scale_figure(options, width, height): -def swap_values(dict, key1, key2): - - key1_val = dict[key1] - dict[key1] = dict[key2] - dict[key2] = key1_val - - return dict - - - def generate_colours(options: dict): if not isinstance(options['palettes'], list):