# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
.. module:: base_recon
:platform: Unix
:synopsis: A base class for all reconstruction methods
.. moduleauthor:: Mark Basham <scientificsoftware@diamond.ac.uk>
"""
import math
import copy
import numpy as np
import subprocess as sp
import os
np.seterr(divide='ignore', invalid='ignore')
import savu.core.utils as cu
from savu.plugins.plugin import Plugin
MAX_OUTER_PAD = 2.1
[docs]class BaseRecon(Plugin):
def __init__(self, name='BaseRecon'):
    """Create the plugin and reset all reconstruction state.

    :param name: Plugin name passed on to the parent constructor.
    """
    super(BaseRecon, self).__init__(name)

    # Dataset counts; nOut is what nOutput_datasets() reports.
    self.nIn = 1
    self.nOut = 1

    # Dimension indices, filled in later when they are known.
    self.scan_dim = None
    self.rep_dim = None

    # Shape of the output volume (set in base_pre_process).
    self.br_vol_shape = None

    # Per-frame values refreshed by base_process_frames_before.
    self.frame_angles = None
    self.frame_cors = None
    self.frame_init_data = None

    # Centre of rotation bookkeeping.
    self.centre = None
    self.cor_shift = 0
    self.cor_as_dataset = False

    # Padding bookkeeping.
    self.base_pad_amount = None
    self.padding_alg = False

    # Optional extras.
    self.projection_shifts = None
    self.init_vol = False
def base_pre_process(self):
    """Cache everything the per-frame processing needs.

    Reads the plugin datasets and input metadata, then stores the padding
    dimension, output volume shape, centre of rotation, projection shifts,
    rotation angles, sinogram padding and the sinogram/centre helper
    functions on the instance.
    """
    in_data, out_data = self.get_datasets()
    in_pData, out_pData = self.get_plugin_datasets()
    sino_data, vol_data = in_data[0], out_data[0]
    sino_pdata = in_pData[0]

    self.pad_dim = sino_pdata.get_data_dimension_by_axis_label(
        'x', contains=True)
    in_meta_data = self.get_in_meta_data()[0]
    self.exp.log(self.name + " End")

    # The volume shape must be known before the centre of rotation is set,
    # because the default centre is derived from it.
    self.br_vol_shape = out_pData[0].get_shape()
    self.set_centre_of_rotation(sino_data, vol_data, in_meta_data)
    self.set_projection_shifts(sino_data, vol_data, in_meta_data)

    self.main_dir = sino_data.get_data_patterns()['SINOGRAM']['main_dir']
    self.angles = in_meta_data.get('rotation_angle')
    multi_dim_angles = len(self.angles.shape) != 1
    if multi_dim_angles:
        self.scan_dim = sino_data.get_data_dimension_by_axis_label('scan')
    self.slice_dirs = vol_data.get_slice_dimensions()

    # sino_pad has to exist before set_function builds the padding lambda.
    shape = sino_pdata.get_shape()
    pad_fraction = self.__get_outer_pad()
    self.sino_pad = int(math.ceil(pad_fraction * shape[self.pad_dim]))
    self.sino_func, self.cor_func = self.set_function(shape)

    self.range = self.parameters['force_zero']
    self.fix_sino = self.get_sino_centre_method()
def __get_outer_pad(self):
    """Return the outer padding as a fraction of the sinogram width.

    The 'outer_pad' parameter may be missing (treated as False), a boolean
    or a number. True selects sqrt(2) - 1, False selects 0, and a number is
    used directly but capped at MAX_OUTER_PAD, with a user message when the
    cap is applied.

    :returns: The padding fraction.
    """
    if 'outer_pad' in self.parameters:
        requested = self.parameters['outer_pad']
    else:
        requested = False

    if isinstance(requested, bool):
        # length of diagonal of square is side*sqrt(2)
        return math.sqrt(2) - 1 if requested is True else 0

    fraction = float(requested)
    if fraction > MAX_OUTER_PAD:
        cu.user_message('Maximum outer_pad value is 2.1, using this instead')
        fraction = MAX_OUTER_PAD
    return fraction
def get_vol_shape(self):
    """Return the stored output volume shape (``self.br_vol_shape``)."""
    vol_shape = self.br_vol_shape
    return vol_shape
def set_projection_shifts(self, inData, outData, mData):
    """Load or initialise the projection shifts and copy them to the output.

    If the experiment metadata already holds 'projection_shifts' they are
    reused. Otherwise an all-zero array with two columns and one row per
    entry along the first input axis is created, stored in the experiment
    metadata AND on the instance. Previously the instance attribute was
    left untouched in that case, so a stale value (possibly None) was
    copied to the output.

    :param inData: Input dataset; only its shape is read.
    :param outData: Output dataset whose metadata receives a deep copy.
    :param mData: Unused; kept so the signature stays unchanged.
    """
    exp_meta = self.exp.meta_data
    if 'projection_shifts' in exp_meta.dict.keys():
        self.projection_shifts = exp_meta.dict['projection_shifts']
    else:
        # initialise a 2d array of projection shifts
        self.projection_shifts = np.zeros((inData.get_shape()[0], 2))
        exp_meta.set('projection_shifts', self.projection_shifts)
    outData.meta_data.set("projection_shifts",
                          copy.deepcopy(self.projection_shifts))
def set_centre_of_rotation(self, inData, outData, mData):
    """Work out the per-frame centre of rotation and publish it.

    The value is taken, in order of preference, from:

    * nothing, if the 'centre_of_rotation' parameter is a string (the
      centre is then supplied as a dataset and this method returns early);
    * the 'centre_of_rotation' entry of ``mData`` if present;
    * a linear fit, if the parameter is a dict;
    * the parameter value repeated once per frame, where 0 means "use half
      the detector-x extent of the output volume".

    The result is stored in ``self.cor``, a deep copy is written to the
    output metadata and ``self.centre`` is set to the first entry.

    :param inData: Input dataset (shape and slice dimensions are read).
    :param outData: Output dataset whose metadata receives the values.
    :param mData: Input metadata, checked for an existing centre.
    """
    val = self.parameters['centre_of_rotation']
    # if cor has been passed as a dataset then do nothing
    if isinstance(val, str):
        return

    if 'centre_of_rotation' in mData.get_dictionary().keys():
        cor = self.__set_param_from_meta_data(
            mData, inData, 'centre_of_rotation')
    elif isinstance(val, dict):
        cor = self.__polyfit_cor(val, inData)
    else:
        sdirs = inData.get_slice_dimensions()
        # np.prod([]) is the float 1.0, which np.ones rejects as a length,
        # so cast to int to support data without slice dimensions.
        n_frames = int(np.prod([inData.get_shape()[i] for i in sdirs]))
        cor = np.ones(n_frames)
        # if centre of rotation has not been set then fix it in the
        # centre
        val = val if val != 0 else \
            (self.get_vol_shape()[self._get_detX_dim()]) / 2.0
        cor *= val

    # mData.set('centre_of_rotation', cor) see Github ticket
    self.cor = cor
    outData.meta_data.set("centre_of_rotation", copy.deepcopy(self.cor))
    self.centre = self.cor[0]
def __set_param_from_meta_data(self, mData, inData, meta_string):
    """Fetch a metadata entry, tiling it if there are more frames than values.

    The frame count is the product of the input shape over its slice
    dimensions. When that exceeds the number of stored values, the values
    are repeated ``frames // len(values)`` times.

    :param mData: Metadata object providing ``get(meta_string)``.
    :param inData: Input dataset (shape and slice dimensions are read).
    :param meta_string: Name of the metadata entry.
    :returns: The entry, tiled when necessary.
    """
    values = mData.get(meta_string)
    shape_of = inData.get_shape
    frame_count = np.prod(
        [shape_of()[dim] for dim in inData.get_slice_dimensions()])
    if frame_count <= len(values):
        return values
    return np.tile(values, frame_count // len(values))
def __polyfit_cor(self, cor_dict, inData):
    """Fit a straight line through the given centres and evaluate it.

    The keys of ``cor_dict`` (converted to int) are the x values and its
    values are the centres. The fitted line is evaluated at the
    'detector_y' metadata values if present, otherwise at
    ``0 .. n-1`` where n is the length of the 'detector_y' dimension.

    :param cor_dict: Mapping of position to centre of rotation.
    :param inData: Input dataset providing metadata and shape.
    :returns: The fitted centre at every evaluation point.
    """
    meta = inData.meta_data
    if 'detector_y' in meta.get_dictionary().keys():
        positions = meta.get('detector_y')
    else:
        y_dim = inData.get_data_dimension_by_axis_label('detector_y')
        positions = np.arange(inData.get_shape()[y_dim])

    known_x = [int(key) for key in cor_dict.keys()]
    known_cor = list(cor_dict.values())
    line = np.poly1d(np.polyfit(known_x, known_cor, 1))
    return line(positions)
def set_function(self, pad_shape):
    """Build the sinogram and centre-of-rotation transform functions.

    With 'centre_pad' missing or falsy the centre is returned unchanged and
    the sinogram function is built without padding. Otherwise the centre is
    offset by ``self.sino_pad`` and the sinogram function pads its input.
    The 'log' parameter selects whether the log function is applied.

    :param pad_shape: Shape passed to the padding setup when padding.
    :returns: Tuple ``(sino_func, cor_func)``.
    """
    if 'centre_pad' in self.parameters:
        centre_pad = self.parameters['centre_pad']
    else:
        centre_pad = False
    take_log = True if self.parameters['log'] else False

    if centre_pad:
        def cor_func(cor):
            return cor + self.sino_pad
        sino_func = self.__make_lambda(log=take_log, pad=pad_shape)
    else:
        def cor_func(cor):
            return cor
        sino_func = self.__make_lambda(log=take_log)
    return sino_func, cor_func
def __make_lambda(self, log=True, pad=False):
    """Build the function applied to every sinogram before reconstruction.

    The function body is an expression string in the variable ``sino``:
    the 'log_func' parameter when ``log`` is true, otherwise a plain
    ``np.nan_to_num``. When ``pad`` is given, every occurrence of ``sino``
    in that expression is replaced by an ``np.pad`` call.

    :param log: Use the configured log expression if true.
    :param pad: Shape used to derive pad widths, or False for no padding.
    :returns: A one-argument callable.
    """
    expression = \
        self.parameters['log_func'] if log else 'np.nan_to_num(sino)'
    if pad:
        pad_tuples, mode = self.__get_pad_values(pad)
        padded_sino = 'np.pad(sino, {}, "{}")'.format(pad_tuples, mode)
        expression = expression.replace('sino', padded_sino)
    # NOTE(security): 'log_func' comes from the plugin parameters and is
    # evaluated as Python code; it must only come from a trusted source.
    return eval("lambda sino: " + expression)
def __get_pad_values(self, pad_shape):
    """Return the per-axis pad widths and the padding mode.

    Every axis gets ``(0, 0)`` except the one at ``self.pad_dim``, which is
    padded by ``self.sino_pad`` on both sides.

    :param pad_shape: Shape of the array to pad; only its length is used.
    :returns: Tuple ``(pad_widths, 'edge')``.
    """
    widths = [(0, 0) for _ in range(len(pad_shape) - 1)]
    widths.insert(self.pad_dim, (self.sino_pad, self.sino_pad))
    return tuple(widths), 'edge'
def base_process_frames_before(self, data):
    """Prepare the per-frame state and the sinogram before reconstruction.

    Stores the angles, centres of rotation and optional initial volume for
    the current frames on the instance, then replaces ``data[0]`` with the
    transformed and centre-corrected sinogram.

    Fixes over the previous version:

    * NaNs in the initial volume are now really set to zero (the old code
      used ``==``, a comparison with no effect).
    * ``scan_dim`` is compared against None, so a scan dimension at
      index 0 is no longer treated as "no scan dimension".

    :param data: List of input arrays; ``data[0]`` is the sinogram,
        ``data[1]`` the initial volume when ``self.init_vol`` is set, and
        the last entry the centres when ``self.cor_as_dataset`` is set.
    :returns: The same list with ``data[0]`` replaced.
    """
    sl = self.get_current_slice_list()[0]
    init = data[1] if self.init_vol else None

    if self.scan_dim is not None:
        angles = self.angles[:, sl[self.scan_dim]]
    else:
        angles = self.angles
    self.frame_angles = np.squeeze(angles)

    dim_sl = sl[self.main_dir]

    if self.cor_as_dataset:
        self.frame_cors = self.cor_func(data[-1])
    else:
        frame_nos = \
            self.get_plugin_in_datasets()[0].get_current_frame_idx()
        self.frame_cors = self.cor_func(self.cor[tuple([frame_nos])])

    # for extra padded frames that make up the numbers
    if not self.frame_cors.shape:
        self.frame_cors = np.array([self.centre])

    len_data = len(np.arange(dim_sl.start, dim_sl.stop, dim_sl.step))
    # NOTE(review): this only appends when there are MORE centres than
    # frames; the operands may be the wrong way round - confirm intent.
    missing = [self.centre] * (len(self.frame_cors) - len_data)
    self.frame_cors = np.append(self.frame_cors, missing)

    # fix to remove NaNs in the initialised image
    if init is not None:
        init[np.isnan(init)] = 0.0
    self.frame_init_data = init

    data[0] = self.fix_sino(self.sino_func(data[0]), self.frame_cors[0])
    return data
def base_process_frames_after(self, data):
    """Clean the reconstructed data.

    NaNs are replaced via ``np.nan_to_num``; values below the lower bound
    or above the upper bound of ``self.range`` are then set to zero. A
    bound of None disables that check.

    :param data: Reconstructed array.
    :returns: The cleaned array.
    """
    lower_range, upper_range = self.range
    cleaned = np.nan_to_num(data)
    for bound, is_outside in ((lower_range, np.less),
                              (upper_range, np.greater)):
        if bound is None:
            continue
        cleaned[is_outside(cleaned, bound)] = 0.0
    return cleaned
def pad_sino(self, sino, cor):
    """Pad the sinogram so the centre of rotation is at the centre.

    The pad widths along the detector-x axis come from
    ``get_centre_offset``; all other axes are left unpadded. The leading
    pad is stored in ``self.cor_shift`` and the larger of the two widths is
    recorded as the pad amount.

    :param sino: Sinogram array.
    :param cor: Centre of rotation for this sinogram.
    :returns: The edge-padded sinogram.
    """
    detX = self._get_detX_dim()
    offsets = self.get_centre_offset(sino, cor, detX)
    self.cor_shift = offsets[0]

    widths = [(0, 0) for _ in range(len(sino.shape) - 1)]
    widths.insert(detX, tuple(offsets))

    self.__set_pad_amount(max(offsets))
    return np.pad(sino, tuple(widths), mode='edge')
def _get_detX_dim(self):
    """Return the index of the dimension whose axis label contains 'x'.

    The lookup is done on the first plugin input dataset.
    """
    first_in = self.get_plugin_in_datasets()[0]
    detX = first_in.get_data_dimension_by_axis_label('x', contains=True)
    return detX
def get_centre_offset(self, sino, cor, detX):
    """Return the (low, high) pad widths along the detector-x axis.

    The result is the centring pad from ``br_array_pad`` plus, on each
    side, half of ``self.sino_pad`` scaled by the ratio of the original
    width to the centred width.

    :param sino: Sinogram array.
    :param cor: Centre of rotation.
    :param detX: Index of the detector-x axis.
    :returns: Array of two pad widths.
    """
    width = sino.shape[detX]
    centre_pad = self.br_array_pad(cor, width)
    centred_width = width + max(centre_pad)
    scaled = math.ceil(float(width) / centred_width * self.sino_pad)
    per_side = int(scaled // 2)
    return np.array([per_side, per_side]) + centre_pad
def get_centre_shift(self, sino, cor):
    """Return the larger of the two centre-offset pad widths.

    NOTE(review): the ``cor`` argument is not used; the offset is computed
    from ``self.centre`` - confirm this is intended.

    :param sino: Sinogram array.
    :param cor: Unused.
    """
    offsets = self.get_centre_offset(sino, self.centre, self._get_detX_dim())
    return max(offsets)
def crop_sino(self, sino, cor):
    """Crop the sinogram so the centre of rotation is at the centre.

    The two values from ``br_array_pad`` are used in reverse order as the
    number of columns to drop from the start and the end of the
    detector-x axis. ``self.cor_shift`` is set to minus the start crop and
    ``set_mask`` is called with the cropped shape.

    :param sino: Sinogram array.
    :param cor: Centre of rotation.
    :returns: The cropped sinogram.
    """
    detX = self._get_detX_dim()
    width = sino.shape[detX]
    # br_array_pad gives (low, high); cropping uses them swapped.
    stop, start = self.br_array_pad(cor, width)
    self.cor_shift = -start

    selection = [slice(None) for _ in sino.shape]
    selection[detX] = slice(start, width - stop)
    cropped = sino[tuple(selection)]

    self.set_mask(cropped.shape)
    return cropped
def br_array_pad(self, ctr, nPixels):
    """Return the (low, high) padding that centres ``ctr``.

    With ``width = nPixels - 1``, the shift is the rounded absolute
    difference between the distances from ``ctr`` to either end. It goes on
    the high side when ``ctr`` is greater than ``(width - 1) / 2`` and on
    the low side otherwise.

    :param ctr: Centre position.
    :param nPixels: Number of pixels along the axis.
    :returns: Integer array ``[low, high]``.
    """
    width = nPixels - 1.0
    left_len = ctr
    right_len = width - ctr
    shift = int(round(abs(right_len - left_len)))
    if ctr > (width - 1.0) / 2.0:
        return np.array([0, shift])
    return np.array([shift, 0])
def keep_sino(self, sino, cor):
    """Return the sinogram untouched; ``cor`` is ignored."""
    unchanged = sino
    return unchanged
def get_sino_centre_method(self):
    """Choose how the sinogram is centred on the centre of rotation.

    :returns: ``keep_sino`` when there is no 'centre_pad' parameter,
        ``pad_sino`` when it is True and ``crop_sino`` when it is False.
    :raises Exception: If 'centre_pad' is present but not a boolean.
    """
    if 'centre_pad' not in list(self.parameters.keys()):
        return self.keep_sino

    cpad = self.parameters['centre_pad']
    if not isinstance(cpad, bool):
        raise Exception('Unknown value for "centre_pad", please choose'
                        ' True or False.')
    if cpad:
        return self.pad_sino
    return self.crop_sino
def __set_pad_amount(self, pad_amount):
    """Store the pad amount returned later by ``get_pad_amount``."""
    setattr(self, 'base_pad_amount', pad_amount)
def get_pad_amount(self):
    """Return the stored pad amount (``self.base_pad_amount``)."""
    pad_amount = self.base_pad_amount
    return pad_amount
def get_fov_fraction(self, sino, cor):
    """Get the fraction of the original FOV that can be reconstructed due
    to offset centre.

    Computed as ``(width - shift) / width``, where width is the sinogram
    extent along the 'x' axis and shift comes from ``get_centre_shift``.

    :param sino: Sinogram array.
    :param cor: Centre of rotation, passed on to ``get_centre_shift``.
    :returns: The fraction as a float.
    """
    first_in = self.get_plugin_in_datasets()[0]
    x_dim = first_in.get_data_dimension_by_axis_label('x', contains=True)
    width = sino.shape[x_dim]
    lost = self.get_centre_shift(sino, cor)
    return (width - lost) / float(width)
def get_reconstruction_alg(self):
    """Return None; this base implementation names no algorithm."""
    algorithm = None
    return algorithm
def get_angles(self):
    """Return the angles stored for the current sinogram(s).

    :returns: ``self.frame_angles`` as set by
        ``base_process_frames_before``.
    :rtype: np.ndarray
    """
    current_angles = self.frame_angles
    return current_angles
def get_proj_shifts(self):
    """Return the 2D (X-Y) shifts associated with every projection frame.

    :returns: Projection shifts (``self.projection_shifts``).
    :rtype: np.ndarray
    """
    shifts = self.projection_shifts
    return shifts
def get_cors(self):
    """Return the centres of rotation for the current sinogram(s).

    The stored per-frame centres are offset by ``self.cor_shift``, which is
    updated when the sinogram is padded or cropped.

    :returns: Centre of rotation values for the current frames.
    :rtype: np.ndarray
    """
    shifted = self.frame_cors + self.cor_shift
    return shifted
def set_mask(self, shape):
    """Do nothing; called by ``crop_sino`` with the cropped shape."""
    return None
def get_initial_data(self):
    """Return the initial data (if any) for the current sinogram(s).

    :returns: The section of the initialisation data associated with the
        current frames (``self.frame_init_data``).
    :rtype: np.ndarray or None
    """
    initial = self.frame_init_data
    return initial
def get_frame_params(self):
    """Return the per-frame parameters as one list.

    :returns: ``[centres, angles, volume shape, initial data]``.
    """
    cors = self.get_cors()
    angles = self.get_angles()
    vol_shape = self.get_vol_shape()
    initial = self.get_initial_data()
    return [cors, angles, vol_shape, initial]
def get_frame_shifts(self):
    """Return the projection shifts; same as ``get_proj_shifts``."""
    shifts = self.get_proj_shifts()
    return shifts
[docs] def setup(self):
in_dataset, out_dataset = self.get_datasets()
# reduce the data as per data_subset parameter
self.preview_flag = \
self.set_preview(in_dataset[0], self.parameters['preview'])
self._set_volume_dimensions(in_dataset[0])
axis_labels = self._get_axis_labels(in_dataset[0])
shape = self._get_shape(in_dataset[0])
# output dataset
out_dataset[0].create_dataset(axis_labels=axis_labels, shape=shape)
out_dataset[0].add_volume_patterns(*self._get_volume_dimensions())
# set information relating to the plugin data
in_pData, out_pData = self.get_plugin_datasets()
self.init_vol = 1 if 'init_vol' in list(self.parameters.keys()) and\
self.parameters['init_vol'] else 0
self.cor_as_dataset = 1 if isinstance(
self.parameters['centre_of_rotation'], str) else 0
for i in range(len(in_dataset) - self.init_vol - self.cor_as_dataset):
in_pData[i].plugin_data_setup('SINOGRAM', self.get_max_frames(),
slice_axis=self.get_slice_axis())
idx = 1
# initial volume dataset
if self.init_vol:
# from savu.data.data_structures.data_types import Replicate
# if self.rep_dim:
# in_dataset[idx].data = Replicate(
# in_dataset[idx], in_dataset[0].get_shape(self.rep_dim))
in_pData[1].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
idx += 1
# cor dataset
if self.cor_as_dataset:
self.cor_as_dataset = True
in_pData[idx].plugin_data_setup('METADATA', self.get_max_frames())
# set pattern_name and nframes to process for all datasets
out_pData[0].plugin_data_setup('VOLUME_XZ', self.get_max_frames())
meta_list = ['rotation_angle'] # metadata list to populate
in_meta_data = self.get_in_meta_data()[0]
if 'projection_shifts' in list(self.exp.meta_data.dict.keys()):
self.projection_shifts = self.exp.meta_data.dict['projection_shifts']
else:
self.projection_shifts = np.zeros((in_dataset[0].get_shape()[self.volX], 2)) # initialise a 2d array of projection shifts
self.exp.meta_data.set('projection_shifts', copy.deepcopy(self.projection_shifts))
out_dataset[0].meta_data.set("projection_shifts", copy.deepcopy(self.projection_shifts))
self.populate_metadata_to_output(in_dataset[0], out_dataset[0], in_meta_data, meta_list)
def _get_axis_labels(self, in_dataset):
    """
    Get the new axis labels for the output dataset - this is now a volume.

    Parameters
    ----------
    in_dataset : :class:`savu.data.data_structures.data.Data`
        The input dataset to the plugin.

    Returns
    -------
    labels : dict
        Maps ``in_dataset`` to a list of label strings of the form
        ``'<dim>.voxel_x.voxels'``; x and z are always present, y only
        when the y dimension is truthy.
    """
    # The value is unused; the lookup is kept so any error it raises is
    # preserved.
    _ = in_dataset.data_info.get('axis_labels')[0]

    volX, volY, volZ = self._get_volume_dimensions()
    labels = ['{}.voxel_x.voxels'.format(volX),
              '{}.voxel_z.voxels'.format(volZ)]
    # NOTE(review): a y dimension at index 0 is falsy and would be
    # skipped here - confirm whether that can occur.
    if volY:
        labels = labels + ['{}.voxel_y.voxels'.format(volY)]
    return {in_dataset: labels}
def _set_volume_dimensions(self, data):
    """
    Map the input dimensions to the output volume dimensions.

    After finalising the dataset patterns, ``volX`` is taken from the
    'rotation_angle' axis, ``volZ`` from the axis whose label contains 'x'
    and ``volY`` from the axis whose label contains 'y' (looked up with
    ``exists=True``).

    Parameters
    ----------
    data : :class:`savu.data.data_structures.data.Data`
        The input dataset to the plugin.
    """
    data._finalise_patterns()
    find_dim = data.get_data_dimension_by_axis_label
    self.volX = find_dim("rotation_angle")
    self.volZ = find_dim("x", contains=True)
    self.volY = find_dim("y", contains=True, exists=True)
def _get_volume_dimensions(self):
    """Return the volume dimension indices as ``(volX, volY, volZ)``."""
    dims = (self.volX, self.volY, self.volZ)
    return dims
def _get_shape(self, in_dataset):
    """Return the shape of the output volume.

    With 'vol_shape' set to 'auto' or 'fixed' the x extent is copied from
    the z extent; otherwise both are set to the 'vol_shape' value. If a
    'resolution' parameter exists, both extents are floor-divided by it.

    :param in_dataset: Input dataset providing the starting shape.
    :returns: The output shape as a tuple.
    """
    dims = list(in_dataset.get_shape())
    volX, _, volZ = self._get_volume_dimensions()

    requested = self.parameters['vol_shape']
    if requested in ('auto', 'fixed'):
        dims[volX] = dims[volZ]
    else:
        dims[volX] = requested
        dims[volZ] = requested

    if 'resolution' in self.parameters.keys():
        for axis in (volX, volZ):
            dims[axis] = int(dims[axis] // self.parameters['resolution'])
    return tuple(dims)
def get_max_frames(self):
    """
    Number of data frames to pass to each instance of process_frames func

    ``self._max_frames`` is returned when it is truthy; otherwise the
    default 'multiple' is used. (An unreachable second ``return`` after
    the first one has been removed.)

    Returns
    -------
    str or int
        "single", "multiple" or integer (only to be used if the number of
        frames MUST be fixed.)
    """
    return self._max_frames or 'multiple'
def get_slice_axis(self):
    """
    Fix the fastest changing slice dimension

    This base implementation returns None.

    Returns
    -------
    str or None
        str should be the axis_label corresponding to the fastest changing
        dimension
    """
    slice_axis = None
    return slice_axis
def nOutput_datasets(self):
    """Return the number of output datasets (``self.nOut``)."""
    count = self.nOut
    return count
def reconstruct_pre_process(self):
    """
    Hook for pre-processing in a child class; does nothing here.
    """
    return None
def executive_summary(self):
    """Return a short list of messages about this plugin run.

    :returns: A single warning when ``self.preview_flag`` is falsy,
        otherwise ``["Nothing to Report"]``.
    """
    if self.preview_flag:
        return ["Nothing to Report"]
    warning = ("WARNING: Ignoring preview parameters as a preview"
               " has already been applied to the data.")
    return [warning]
def get_gpu_memory(self):
    """Query ``nvidia-smi`` for the free memory of each GPU.

    The first and last lines of the CSV output are dropped (presumably the
    header and a trailing empty line) and the leading number of every
    remaining line is returned.

    :returns: List of integers, one per remaining output line.
    """
    query = ["nvidia-smi", "--query-gpu=memory.free", "--format=csv"]
    report = sp.check_output(query).decode('ascii')
    rows = report.split('\n')[1:-1]
    return [int(row.split()[0]) for row in rows]
def _handle_case_1(self, in_text):
    """Evaluate ``in_text`` as a Python expression and return the result.

    :param in_text: Expression string, e.g. one using ``np``.
    """
    # NOTE(security): the text is executed as Python code; it must only
    # come from a trusted source.
    return eval(in_text)
def _handle_case_2(self, in_text):
    """Parse a list that mixes plain integers with ``np.`` expressions.

    Each expression is assumed to start two characters before a '.' and to
    end at the matching-by-order ')'. Expressions are evaluated and their
    values collected; in the text they are then overwritten by a '-1'
    placeholder of the same length so the remaining comma-separated
    integers can be parsed. Placeholders are removed again afterwards.

    :param in_text: Text such as ``"[1, np.arange(3,5), 7]"``.
    :returns: Sorted unique indices as an int16 array.
    """
    starts = [i - 2 for i, ch in enumerate(in_text) if ch == "."]
    stops = [i for i, ch in enumerate(in_text) if ch == ")"]

    # First pass: evaluate every expression against the untouched text.
    # NOTE(security): the expressions are executed as Python code.
    expanded = []
    for k, start in enumerate(starts):
        expanded.extend(eval(in_text[start:stops[k] + 1]))

    # Second pass: blank the expressions out, keeping the text length so
    # the positions found above stay valid.
    for k, start in enumerate(starts):
        stop = stops[k]
        placeholder = "-1".ljust(stop - start + 1)
        in_text = in_text.replace(in_text[start:stop + 1], placeholder)

    tokens = in_text.replace("[", "").replace("]", "").split(",")
    plain = np.asarray([int(token) for token in tokens])
    plain = plain[plain != -1]

    merged = np.sort(np.concatenate((np.asarray(expanded), plain)))
    return np.unique(np.int16(merged))
def _handle_case_3(self, in_text):
    """Parse a comma-separated list of integers.

    Square brackets are removed and entries that are not integers are
    skipped.

    :param in_text: Text such as ``"[3, 1, 2]"``.
    :returns: Sorted unique integers as an array, or None if none parsed.
    """
    stripped = in_text.replace("[", "").replace("]", "")
    numbers = []
    for token in stripped.split(","):
        try:
            numbers.append(int(token))
        except ValueError:
            continue
    if not numbers:
        return None
    return np.unique(np.sort(numbers))
def get_skipping_indices(self, in_text):
    """Turn a user-supplied index specification into sorted unique indices.

    A list is converted directly. A string containing ``"np."`` is first
    evaluated as a whole; if that does not yield an array it is parsed as
    a mixed list. Any other string is parsed as a plain integer list.

    :param in_text: String or list describing the indices.
    :returns: Array of sorted unique indices, or None when ``in_text`` is
        neither a string nor a list.
    :raises ValueError: If a string does not produce an array.
    """
    if isinstance(in_text, list):
        return np.unique(np.sort(np.asarray(in_text)))
    if not isinstance(in_text, str):
        return None

    if "np." in in_text:
        evaluated = self._handle_case_1(in_text)
        if isinstance(evaluated, np.ndarray):
            final_idx = np.unique(np.sort(evaluated.flatten()))
        else:
            final_idx = self._handle_case_2(in_text)
    else:
        final_idx = self._handle_case_3(in_text)

    if not isinstance(final_idx, np.ndarray):
        raise ValueError("Incorrect syntax!!!")
    return final_idx