# Source code for plugins.reconstructions.astra_recons.base_astra_recon

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: base_astra_recon
   :platform: Unix
   :synopsis: A base for all Astra toolbox reconstruction algorithms
.. moduleauthor:: Mark Basham <scientificsoftware@diamond.ac.uk>
"""

import astra
import numpy as np

from savu.plugins.reconstructions.base_recon import BaseRecon
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
    check_if_end_plugin_in_iterate_group

class BaseAstraRecon(BaseRecon):
    """Base class for reconstruction plugins built on the ASTRA toolbox.

    The 2D reconstruction path (``setup_2D`` / ``astra_2D_recon``) is
    implemented here.  The 3D path (``setup_3D`` / ``astra_3D_recon``) and
    ``set_options`` are referenced but not defined in this class, so they
    are expected to be supplied elsewhere in the class hierarchy.
    """

    def __init__(self, name='BaseAstraRecon'):
        super(BaseAstraRecon, self).__init__(name)
        # Switched to True in setup() when a residual-norm output dataset
        # has been requested.
        self.res = False

    def nOutput_datasets(self):
        """Return the total number of output datasets.

        :returns: 2 when ``check_if_end_plugin_in_iterate_group`` reports
            true for this experiment, otherwise 1.
        """
        return 2 if check_if_end_plugin_in_iterate_group(self.exp) else 1

    def nClone_datasets(self):
        """Return the total number of output datasets that are clones.

        :returns: 1 when ``check_if_end_plugin_in_iterate_group`` reports
            true for this experiment, otherwise 0.
        """
        return 1 if check_if_end_plugin_in_iterate_group(self.exp) else 0

    @enable_iterative_loop
    def setup(self):
        """Set up the plugin datasets, adding a res_norm output if needed.

        :raises ValueError: if three output datasets are present while one
            of them is a clone, i.e. res_norm was requested at the end of
            an iterative loop (not implemented).
        """
        self.alg = self.parameters['algorithm']
        self.get_max_frames = \
            self._get_multiple if '3D' in self.alg else self._get_single

        super(BaseAstraRecon, self).setup()
        out_dataset = self.get_out_datasets()
        n_clones = self.nClone_datasets()

        # if res_norm is required then setup another output dataset
        if len(out_dataset) == 3 and n_clones == 1:
            err_str = "The res_norm output dataset has not yet been " \
                "implemented for when AstraReconCpu is at the end of an " \
                "iterative loop"
            raise ValueError(err_str)

        if len(out_dataset) == 2 and n_clones == 0:
            self.res = True
            out_pData = self.get_plugin_out_datasets()
            in_data = self.get_in_datasets()[0]
            dim_detX = \
                in_data.get_data_dimension_by_axis_label('y', contains=True)

            # n_iterations may be a single value or a list of values; the
            # residual array is sized for the largest of them.
            nIts = self.parameters['n_iterations']
            nIts = nIts if isinstance(nIts, list) else [nIts]
            self.len_res = max(nIts)
            shape = (in_data.get_shape()[dim_detX], self.len_res)

            label = ['vol_y.voxel', 'iteration.number']
            pattern = {'name': 'SINOGRAM', 'slice_dims': (0,),
                       'core_dims': (1,)}

            out_dataset[1].create_dataset(axis_labels=label, shape=shape)
            out_dataset[1].add_pattern(pattern['name'],
                                       slice_dims=pattern['slice_dims'],
                                       core_dims=pattern['core_dims'])
            out_pData[1].plugin_data_setup(
                pattern['name'], self.get_max_frames())

    def pre_process(self):
        """Select the 2D or 3D reconstruction path from the algorithm name.

        Algorithms whose name contains ``'3D'`` use ``setup_3D`` and
        ``astra_3D_recon``; all others use ``setup_2D`` and
        ``astra_2D_recon``.
        """
        self.alg = self.parameters['algorithm']
        self.iters = self.parameters['n_iterations']

        if '3D' in self.alg:
            self.setup_3D()
            self.process_frames = self.astra_3D_recon
        else:
            self.setup_2D()
            self.process_frames = self.astra_2D_recon

    def setup_2D(self):
        """Cache the sinogram dimensions and build the output mask."""
        pData = self.get_plugin_in_datasets()[0]
        self.dim_detX = \
            pData.get_data_dimension_by_axis_label('x', contains=True)
        self.dim_rot = \
            pData.get_data_dimension_by_axis_label('rot', contains=True)

        self.sino_shape = pData.get_shape()
        self.nDims = len(self.sino_shape)
        self.nCols = self.sino_shape[self.dim_detX]
        self.set_mask(self.sino_shape)

    def set_mask(self, shape):
        """Create ``self.manual_mask``, a circular mask for the output.

        The mask is 1.0 inside a centred circle of diameter
        ``(length - 1) * ratio`` and ``outer_mask`` elsewhere, where
        ``length`` is the first dimension of the first plugin output
        dataset.  The ``ratio`` parameter is either a single value (outer
        value NaN) or a list/tuple ``(ratio, outer_value)``; a string
        ``outer_value`` is replaced by NaN.  When the ``outer_pad``
        parameter and ``self.padding_alg`` are both truthy no mask is
        built and ``self.manual_mask`` is set to ``False``.

        :param shape: not used by this implementation; kept so the
            signature stays compatible with existing callers.
        """
        length = self.get_plugin_out_datasets()[0].get_shape()[0]
        coords = np.linspace(-length / 2.0, length / 2.0, length)
        x, y = np.meshgrid(coords, coords)

        ratio = self.parameters['ratio']
        if isinstance(ratio, (list, tuple)):
            ratio_mask = ratio[0]
            outer_mask = ratio[1]
            if isinstance(outer_mask, str):
                outer_mask = np.nan
        else:
            ratio_mask = ratio
            outer_mask = np.nan
        r = (length - 1) * ratio_mask

        outer_pad = bool(self.parameters['outer_pad'] and self.padding_alg)
        if outer_pad:
            self.manual_mask = False
        else:
            # The builtin float replaces the np.float alias, which was
            # removed from NumPy in 1.24 (same dtype: float64).
            self.manual_mask = \
                np.array((x**2 + y**2 < (r / 2.0)**2), dtype=float)
            self.manual_mask[self.manual_mask == 0] = outer_mask

    def astra_2D_recon(self, data):
        """Reconstruct a single sinogram with the configured algorithm.

        :param data: sequence whose first element is the sinogram array.
        :returns: the reconstruction, or ``[reconstruction, residuals]``
            when the res_norm output is enabled (``self.res``).
        """
        sino = data[0]
        _, angles, vol_shape, init = self.get_frame_params()

        # Drop any projections (and their angles) flagged for skipping.
        skip = self.parameters['skip_projections']
        skip_idx = self.get_skipping_indices(skip)
        if skip_idx is not None:
            max_idx = sino.shape[0]
            skip_idx = np.unique(np.clip(skip_idx, 0, max_idx - 1))
            use_idx = np.setdiff1d(np.arange(max_idx), skip_idx)
            sino = sino[use_idx]
            angles = angles[use_idx]
        angles = np.deg2rad(angles)

        if self.res:
            res = np.zeros(self.len_res)

        # create volume geom
        vol_geom = astra.create_vol_geom(vol_shape)
        # create projection geom
        det_width = sino.shape[self.dim_detX]
        proj_geom = astra.create_proj_geom('parallel', 1.0, det_width, angles)
        sino = np.transpose(sino, (self.dim_rot, self.dim_detX))

        # create sinogram id
        sino_id = astra.data2d.create("-sino", proj_geom, sino)
        # create reconstruction id
        if init is not None:
            rec_id = astra.data2d.create('-vol', vol_geom, init)
        else:
            rec_id = astra.data2d.create('-vol', vol_geom)

        # setup configuration options
        cfg = self.set_config(rec_id, sino_id, proj_geom, vol_geom)
        # create algorithm id
        alg_id = astra.algorithm.create(cfg)

        # run algorithm
        if self.res:
            # Run one iteration at a time so that the residual norm can be
            # recorded after each of them.
            for j in range(self.iters):
                astra.algorithm.run(alg_id, 1)
                res[j] = astra.algorithm.get_res_norm(alg_id)
        else:
            astra.algorithm.run(alg_id, self.iters)

        # get reconstruction matrix
        recon = astra.data2d.get(rec_id)
        if self.manual_mask is not False:
            recon = self.manual_mask * recon

        # Release the ASTRA objects, including the projector that
        # set_config() may have created, so it is not leaked per frame.
        proj_id = cfg.get('ProjectorId', False)
        self.delete(alg_id, sino_id, rec_id, proj_id)
        return [recon, res] if self.res else recon

    def set_config(self, rec_id, sino_id, proj_geom, vol_geom):
        """Build the ASTRA algorithm configuration dictionary.

        :param rec_id: id of the ASTRA reconstruction data object.
        :param sino_id: id of the ASTRA sinogram data object.
        :param proj_geom: ASTRA projection geometry.
        :param vol_geom: ASTRA volume geometry.
        :returns: the configuration after it has been passed through
            ``self.set_options``.
        """
        cfg = astra.astra_dict(self.alg)
        cfg['ReconstructionDataId'] = rec_id
        cfg['ProjectionDataId'] = sino_id

        if 'FBP' in self.alg:
            if 'FBP_filter' in self.parameters:
                cfg['FilterType'] = self.parameters['FBP_filter']
            else:
                cfg['FilterType'] = 'none'

        if 'projector' in self.parameters:
            proj_id = astra.create_projector(
                self.parameters['projector'], proj_geom, vol_geom)
            cfg['ProjectorId'] = proj_id

        cfg = self.set_options(cfg)
        return cfg

    def delete(self, alg_id, sino_id, rec_id, proj_id):
        """Delete the ASTRA objects created for one reconstruction.

        :param alg_id: algorithm id.
        :param sino_id: sinogram data id.
        :param rec_id: reconstruction data id.
        :param proj_id: projector id; a falsy value skips the projector.
        """
        astra.algorithm.delete(alg_id)
        astra.data2d.delete(sino_id)
        astra.data2d.delete(rec_id)
        if proj_id:
            astra.projector.delete(proj_id)

    def _get_single(self):
        """Return the frame-count setting used by non-3D algorithms."""
        return 'single'

    def _get_multiple(self):
        """Return the frame-count setting used by 3D algorithms."""
        return 'multiple'