Source code for plugins.driver.gpu_plugin

# Copyright 2014 Diamond Light Source Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
.. module:: gpu_plugin
   :platform: Unix
   :synopsis: The driver for GPU plugins.

.. moduleauthor:: Mark Basham <scientificsoftware@diamond.ac.uk>

"""

import os
import copy
import logging
import numpy as np
import pynvml as pv
from mpi4py import MPI
from itertools import chain

from savu.plugins.driver.plugin_driver import PluginDriver
from savu.plugins.driver.basic_driver import BasicDriver

_base = BasicDriver if os.environ['savu_mode'] == 'basic' else PluginDriver


class GpuPlugin(_base):

    def __init__(self):
        super(GpuPlugin, self).__init__()

    def _run_plugin(self, exp, transport):
        expInfo = exp.meta_data
        processes = copy.copy(expInfo.get("processes"))
        process = expInfo.get("process")

        # flag which of the original processes are GPU processes
        gpu_processes = [False] * len(processes)
        idx = [i for i in range(len(processes)) if 'GPU' in processes[i]]
        for i in idx:
            gpu_processes[i] = True

        # set only GPU processes
        new_processes = [i for i in processes if 'GPU' in i]
        if not new_processes:
            raise Exception("THERE ARE NO GPU PROCESSES!")
        expInfo.set('processes', new_processes)

        # split the GPU world ranks into per-node blocks, then interleave
        # them across nodes so that __calculate_GPU_index assigns one
        # process to each GPU on every node
        nNodes = new_processes.count(new_processes[0])
        ranks = [i for i, x in enumerate(gpu_processes) if x]
        idx = [i for i in range(len(ranks)) if new_processes[i] == 'GPU0']
        diff = np.diff(np.array(idx)) if len(idx) > 1 else 1
        split = np.max(diff) if not isinstance(diff, int) else len(ranks)
        split_ranks = \
            [ranks[n:n + split] for n in range(0, len(ranks), split)]
        ranks = list(chain.from_iterable(zip(*split_ranks)))

        self.__create_new_communicator(ranks, exp, process)

        if gpu_processes[process]:
            self.stats_obj.GPU = True
            expInfo.set('process', self.new_comm.Get_rank())
            GPU_index = self.__calculate_GPU_index(nNodes)
            logging.debug("Running the GPU process %i with GPU index %i",
                          self.new_comm.Get_rank(), GPU_index)
            self.parameters['GPU_index'] = GPU_index
            os.environ['CUDA_DEVICE'] = str(GPU_index)
            self._run_plugin_instances(transport, communicator=self.new_comm)
            self.__free_communicator()
            expInfo.set('process', MPI.COMM_WORLD.Get_rank())
        else:
            logging.info('Not a GPU process: Waiting...')

        if self.stats_obj.calc_stats:
            self.stats_obj._broadcast_gpu_stats(gpu_processes, process)

        self.exp._barrier()
        expInfo.set('processes', processes)
        return

    def __create_new_communicator(self, ranks, exp, process):
        self.group = MPI.COMM_WORLD.Get_group()
        self.new_group = MPI.Group.Incl(self.group, ranks)
        self.new_comm = MPI.COMM_WORLD.Create(self.new_group)
        self.exp._barrier()

    def __free_communicator(self):
        self.group.Free()
        self.new_group.Free()
        self.new_comm.Free()

    def __calculate_GPU_index(self, nNodes):
        pv.nvmlInit()
        nGPUs = int(pv.nvmlDeviceGetCount())
        rank = self.new_comm.Get_rank()
        return int(rank / nNodes) % nGPUs
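

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a standalone,
# MPI- and pynvml-free walk-through of the rank regrouping and GPU index
# arithmetic used in _run_plugin and __calculate_GPU_index above. The
# process list is a hypothetical example (two nodes, each running one CPU
# and two GPU processes) and nGPUs is assumed to be 2; in the real driver
# the process list comes from the experiment metadata and the GPU count
# from pynvml. Reuses the numpy and itertools imports at the top of this
# module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    processes = ['CPU0', 'GPU0', 'GPU1', 'CPU0', 'GPU0', 'GPU1']
    gpu_processes = ['GPU' in p for p in processes]
    new_processes = [p for p in processes if 'GPU' in p]

    # the number of nodes is the number of times the first GPU label repeats
    nNodes = new_processes.count(new_processes[0])

    # world ranks of the GPU processes, split into per-node blocks and then
    # interleaved across nodes (mirroring _run_plugin)
    ranks = [i for i, x in enumerate(gpu_processes) if x]
    idx = [i for i in range(len(ranks)) if new_processes[i] == 'GPU0']
    diff = np.diff(np.array(idx)) if len(idx) > 1 else 1
    split = np.max(diff) if not isinstance(diff, int) else len(ranks)
    split_ranks = [ranks[n:n + split] for n in range(0, len(ranks), split)]
    interleaved = list(chain.from_iterable(zip(*split_ranks)))

    # GPU index each rank in the new communicator would receive, assuming
    # 2 GPUs per node (mirroring __calculate_GPU_index)
    nGPUs = 2
    gpu_index = [int(r / nNodes) % nGPUs for r in range(len(interleaved))]

    print("GPU ranks (world order):", ranks)        # [1, 2, 4, 5]
    print("New communicator order: ", interleaved)  # [1, 4, 2, 5]
    print("GPU index per new rank: ", gpu_index)    # [0, 0, 1, 1]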