Source code for wpg.wavefront

# -*- coding: utf-8 -*-
"""
This module contains the base wrapper for SRWLWfr (Wavefront). It implements numpy interoperability with the SRWLWfr structure, serialization to HDF5, visualization tools, etc.

.. module:: wpg.wavefront
   :platform: Linux, Mac OSX, Windows

.. moduleauthor:: Alexey Buzmakov <buzmakov@gmail.com>
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import array
import warnings

import numpy as np
import h5py

import wpg.srwlib as srwlib

import wpg.utils as utils
import wpg.glossary as glossary

from wpg.utils import srw_obj2str

# NOTE(review): ``Warning`` is the base class of all warning categories, so
# this installs a process-wide filter that silences *every* warning as soon as
# this module is imported -- including warnings raised by unrelated libraries
# (e.g. deprecation notices). Consider scoping it with
# ``warnings.catch_warnings()`` around the specific calls that need it.
warnings.filterwarnings('ignore', category=Warning)


class Wavefront(object):
    """
    Base class for manipulating wavefronts in the wpg module.

    The most important field is ``_srwl_wf`` (an instance of srwlib.SRWLWfr).
    Setting and getting this field allows calling all SRWLpy functions on the
    wrapped wavefront.
    """

    def __init__(self, srwl_wavefront=None):
        """
        Create a wavefront instance.

        The most important wavefront fields are dynamically initialized from
        :mod:`wpg.glossary`.

        :param srwl_wavefront: if present, the wavefront is initialized with its parameters
        :type srwl_wavefront: srwlib.SRWLWfr
        :return: Wavefront instance.
        """
        if srwl_wavefront is None:
            self._srwl_wf = srwlib.SRWLWfr()
        else:
            self._srwl_wf = srwl_wavefront

        # glossary_name -> RadiationField, filled by _add_field
        self._wf_fields = {}
        # extra (non-glossary) values, round-tripped through HDF5
        self.custom_fields = {}

        for wf_field in glossary.get_wf_fields():
            self._add_field(wf_field(self))

    def _get_total_elements(self):
        """
        Get the total number of points in the wavefront.

        :return: nx * ny * nSlices of the wavefront mesh
        """
        mesh = self.params.Mesh
        return mesh.nx * mesh.ny * mesh.nSlices

    def _allocate_srw_moments(self):
        """Allocate zero-filled moment arrays (11 doubles per slice) for SRW."""
        n_values = self.params.Mesh.nSlices * 11
        # str(u'd'): array.array needs a native str typecode on Python 2
        # when unicode_literals is active.
        self._srwl_wf.arMomX = array.array(str(u'd'), [0] * n_values)
        self._srwl_wf.arMomY = array.array(str(u'd'), [0] * n_values)

    def _add_field(self, wf_field):
        """
        Register a field and expose its value as a (nested) property.

        The field's ``keys_chain`` is walked to build a tree of attribute
        "folders", so that e.g. a chain ``params / Mesh / nx`` becomes reachable
        as ``self.params.Mesh.nx``.

        :param wf_field: field instance
        :type wf_field: wpg.glossary.RadiationField
        :raises TypeError: if ``wf_field`` is not a RadiationField
        """
        if not isinstance(wf_field, glossary.RadiationField):
            raise TypeError('wf_field must be RadiationField')

        # Deliberately defined per call: properties are installed on the
        # folder's class below, so folders must not share one global class.
        class glossary_folder(object):
            """Glossary folder. Empty class to build dictionary tree."""
            pass

        def get_value(folder):
            """Get value stored in field."""
            return wf_field.value

        def set_value(folder, value):
            """
            Set value stored in field.

            :param value: value to be stored
            """
            wf_field.value = value

        self._wf_fields[wf_field.glossary_name] = wf_field

        node = self
        keys_chain = wf_field.keys_chain
        for key in keys_chain[:-1]:
            if key not in node.__dict__:
                node.__dict__[key] = glossary_folder()
            node = node.__dict__[key]

        # NOTE(review): the property is attached to ``node``'s *class*. Folders
        # created in the same call share a class, so their properties leak into
        # each other's namespace; and for a one-element ``keys_chain`` ``node``
        # is ``self``, so the property would land on ``Wavefront`` itself and be
        # shared by every instance. Confirm glossary chains always have >= 2 keys.
        # NOTE(review): ``wf_field.value.__doc__`` is the docstring of the
        # value's type, not of the field -- ``wf_field.__doc__`` may have been
        # intended.
        setattr(node.__class__, keys_chain[-1],
                property(get_value, set_value, doc=wf_field.value.__doc__))

    def _to_dict(self):
        """
        Convert the wavefront to a flat dictionary (used for saving to HDF5).

        :return: {glossary_name: value} for every field, updated with
            ``custom_fields`` (custom entries win on key clashes)
        """
        res = {key: field.value for key, field in self._wf_fields.items()}
        res.update(self.custom_fields)
        return res

    def _update_from_dict(self, in_dict):
        """
        Update the wavefront from a flat dictionary (used when loading HDF5).

        Keys naming a registered field update that field; all other keys are
        stored in ``custom_fields``.

        :param in_dict: input dictionary
        :type in_dict: dict
        """
        for key, value in in_dict.items():
            # Python 3: keys read back from HDF5 may be bytes
            if isinstance(key, bytes):
                key = key.decode('utf-8')
            if key in self._wf_fields:
                self._wf_fields[key].value = value
            else:
                utils.update_dict_slash_string(self.custom_fields, key, value)

    def _store_attributes(self, file_name):
        """
        Store wavefront attributes to an existing HDF5 file.

        Each field's ``attributes`` dict is written as HDF5 attributes of the
        dataset named after the field, if that dataset exists.

        :param file_name: output HDF5 file name
        :type file_name: string
        """
        # Explicit append mode: since h5py 3.0 the default mode is read-only,
        # which would make the attribute writes below fail.
        with h5py.File(file_name, 'a') as h5f:
            for wff in self._wf_fields.values():
                try:
                    if wff.glossary_name in h5f:
                        for k, v in list(wff.attributes.items()):
                            h5f[wff.glossary_name].attrs[k] = v
                except KeyError:
                    # best effort: skip fields whose attributes cannot be stored
                    pass

    def store_hdf5(self, file_name):
        """
        Store the wavefront to an HDF5 file (values and attributes).

        :param file_name: output HDF5 file name
        :type file_name: string
        """
        utils.store_dict_hdf5(file_name, self._to_dict())
        self._store_attributes(file_name)

    def load_hdf5(self, file_name):
        """
        Load the wavefront from an HDF5 file.

        :param file_name: input HDF5 file name
        :type file_name: string
        """
        self._update_from_dict(utils.load_dict_slash_hdf5(file_name))

    @staticmethod
    def _polarization_index(polarization):
        """
        Map a polarization name to the ``pol`` code passed to SRW.

        :param polarization: 'total' (or None), 'horizontal' or 'vertical'
        :return: 6, 0 or 1 respectively
        :raises ValueError: for any other value
        """
        if polarization == 'total' or polarization is None:
            return 6
        if polarization == 'horizontal':
            return 0
        if polarization == 'vertical':
            return 1
        raise ValueError(
            'unknown polarization value, should be "total" or "horizontal" or "vertical"')

    def _field_array(self, int_type, slice_number=None, polarization=None):
        """
        Compute a per-point field quantity with srwl.CalcIntFromElecField.

        :param int_type: ``_int_type`` code for CalcIntFromElecField (this class
            uses 0 for intensity, 5 for the real part, 6 for the imaginary part)
        :param slice_number: slice(s) to return; None returns all slices
        :type slice_number: int or range
        :param polarization: 'total' (default), 'horizontal' or 'vertical'
        :return: float32 array of shape (ny, nx, nSlices), indexed on the last
            axis by ``slice_number`` when given
        :raises ValueError: for an unknown polarization
        """
        pol = self._polarization_index(polarization)
        res = np.zeros(self._get_total_elements(), dtype='float32')
        # Dependence-type argument 6 is presumably SRW's "vs x&y&e" layout,
        # which matches the (ny, nx, nSlices) reshape below -- see SRW docs.
        res = srwlib.srwl.CalcIntFromElecField(
            res, self._srwl_wf, pol, int_type, 6, self.params.photonEnergy, 0, 0)
        # np.asarray avoids a copy when possible; unlike np.array(copy=False)
        # it does not raise under NumPy >= 2 when a copy is required.
        res = np.asarray(res, dtype='float32')
        res.shape = (
            self.params.Mesh.ny, self.params.Mesh.nx, self.params.Mesh.nSlices)
        if slice_number is not None:
            res = res[:, :, slice_number]
        return res

    def get_intensity(self, slice_number=None, polarization=None):
        """
        Return the intensity of the wavefront.

        :param polarization: 'total' or 'horizontal' or 'vertical'
        :type polarization: string
        :param slice_number: slice number to return, if None - get 3D array (all slices)
        :type slice_number: int or range
        :return: array of intensities
        """
        return self._field_array(0, slice_number, polarization)

    def get_phase(self, slice_number=None, polarization=None):
        """
        Return the phase of the wavefront, computed as arctan2(imag, real).

        :param polarization: 'total' or 'horizontal' or 'vertical'
        :type polarization: string
        :param slice_number: slice number to return, if None - get 3D array (all slices)
        :type slice_number: int or range
        :return: array of phases
        """
        # TODO: bug with freeze
        if self._polarization_index(polarization) == 6:
            print(
                'Attention!!! The "total" polarization behavior sometimes strange. Use "horizontal" or "vertical".'
            )
        return np.arctan2(
            self.get_imag_part(slice_number=slice_number, polarization=polarization),
            self.get_real_part(slice_number=slice_number, polarization=polarization))

    def get_real_part(self, slice_number=None, polarization=None):
        """
        Return the real part of the wavefront's electric field.

        :param polarization: 'total' or 'horizontal' or 'vertical'
        :type polarization: string
        :param slice_number: slice number to return, if None - get 3D array (all slices)
        :type slice_number: int or range
        :return: array of real parts
        """
        return self._field_array(5, slice_number, polarization)

    def get_imag_part(self, slice_number=None, polarization=None):
        """
        Return the imaginary part of the wavefront's electric field.

        :param polarization: 'total' or 'horizontal' or 'vertical'
        :type polarization: string
        :param slice_number: slice number to return, if None - get 3D array (all slices)
        :type slice_number: int or range
        :return: array of imaginary parts
        """
        return self._field_array(6, slice_number, polarization)

    def get_limits(self, axis='z'):
        """
        Get the wavefront mesh limits for 2D visualization.

        Intended for e.g. ``pylab.imshow(wfr_data, extent=wfr.get_limits())``.

        :param axis: 'x', 'y' or 'z' -- the axis looked along
        :type axis: string
        :return: 4-tuple (horizontal min, horizontal max, vertical max,
            vertical min), or None if ``axis`` or ``params.wSpace`` is not
            recognised
        """
        mesh = self.params.Mesh
        space = self.params.wSpace
        if space == 'R-space':
            if axis == 'z':
                return mesh.xMin, mesh.xMax, mesh.yMax, mesh.yMin
            if axis == 'x':
                return mesh.sliceMin, mesh.sliceMax, mesh.yMax, mesh.yMin
            if axis == 'y':
                return mesh.sliceMin, mesh.sliceMax, mesh.xMax, mesh.xMin
        elif space == 'Q-space':
            # Wavelength; assuming photonEnergy is in eV this is in metres.
            wl = 12.398 * 1e-10 / (self.params.photonEnergy * 1e-3)
            if axis == 'z':
                return mesh.qxMin * wl, mesh.qxMax * wl, mesh.qyMax * wl, mesh.qyMin * wl
            if axis == 'x':
                return mesh.sliceMin, mesh.sliceMax, mesh.qyMax * wl, mesh.qyMin * wl
            if axis == 'y':
                return mesh.sliceMin, mesh.sliceMax, mesh.qxMax * wl, mesh.qxMin * wl
        return None

    def __str__(self):
        """
        String representation to enable the print function.

        :return: string describing mesh, radiation parameters and data shapes
        """
        mesh_str = ('Mesh:\n\t\t'
                    + '\n\t\t'.join(srw_obj2str(self.params.Mesh).split('\n'))
                    + '\n\t')
        radiation_str = mesh_str + '\n\t'.join(srw_obj2str(self.params).split('\n'))
        radiation_str = 'Radiation:\n\t' + radiation_str + '\n'
        data_ehor = ('\tarrEhor = array of shape '
                     + str(self.data.arrEhor.shape)
                     + ' // the 2-nd dimension is (re,im)\n')
        data_ever = ('\tarrEver = array of shape '
                     + str(self.data.arrEver.shape)
                     + ' // the 2-nd dimension is (re,im)\n')
        return radiation_str + data_ehor + data_ever

    def srw_info(self):
        """
        Return the string representation of ``self._srwl_wf`` (for debugging).

        :return: string
        """
        return srw_obj2str(self._srwl_wf)