#!/usr/bin/env python
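"""run_cellphonedb_rabbitmq.py

RabbitMQ worker for the CellPhoneDB web service: it pulls job messages from a
queue, downloads the input files from S3-compatible storage, runs either a
CellPhoneDB analysis (QUEUE_TYPE 'method') or an R plot (QUEUE_TYPE 'plot'),
uploads the outputs back to S3 and publishes a result message.

Source: https://gitee.com/yao_xiao_fei2/cellphonedb
"""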
import io
import json
import os
import sys
import tempfile
import time
import traceback
from distutils.util import strtobool
from functools import wraps
from logging import INFO
from typing import Callable
import boto3
import pandas as pd
import pika
from cellphonedb.src.app import cpdb_app
from cellphonedb.src.core.exceptions.AllCountsFilteredException import AllCountsFilteredException
from cellphonedb.src.core.exceptions.EmptyResultException import EmptyResultException
from cellphonedb.src.core.exceptions.ThresholdValueException import ThresholdValueException
from cellphonedb.src.core.utils.subsampler import Subsampler
from cellphonedb.src.database.manager.DatabaseVersionManager import list_local_versions, find_database_for
from cellphonedb.src.exceptions.ParseCountsException import ParseCountsException
from cellphonedb.src.exceptions.ParseMetaException import ParseMetaException
from cellphonedb.src.exceptions.PlotException import PlotException
from cellphonedb.src.exceptions.ReadFileException import ReadFileException
from cellphonedb.src.plotters.r_plotter import dot_plot, heatmaps_plot
from cellphonedb.utils import utils
from rabbit_logger import RabbitAdapter, RabbitLogger
rabbit_logger = RabbitLogger()
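
# All connection settings come from environment variables; the worker refuses
# to start if any of them is missing.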
try:
    s3_access_key = os.environ['S3_ACCESS_KEY']
    s3_secret_key = os.environ['S3_SECRET_KEY']
    s3_bucket_name = os.environ['S3_BUCKET_NAME']
    s3_endpoint = os.environ['S3_ENDPOINT']
    rabbit_host = os.environ['RABBIT_HOST']
    rabbit_port = os.environ['RABBIT_PORT']
    rabbit_user = os.environ['RABBIT_USER']
    rabbit_password = os.environ['RABBIT_PASSWORD']
    jobs_queue_name = os.environ['RABBIT_JOB_QUEUE']
    result_queue_name = os.environ['RABBIT_RESULT_QUEUE']
    queue_type = os.environ['QUEUE_TYPE']
except KeyError as e:
    rabbit_logger.error('ENVIRONMENT VARIABLE {} not defined. Please set it'.format(e))
    sys.exit(1)

verbose = bool(strtobool(os.getenv('VERBOSE', 'true')))
if verbose:
    rabbit_logger.setLevel(INFO)
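

# Logging helpers: each job gets its own adapter so that log lines carry the
# job_id, and _track_success logs entry/exit of every decorated step.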
def logger_for_job(job_id):
    return RabbitAdapter.logger_for(rabbit_logger, job_id)


def _track_success(f) -> Callable:
    @wraps(f)
    def wrapper(*args, **kwargs):
        logger = kwargs.get('logger', rabbit_logger)
        logger.info('calling {} method'.format(f.__name__))
        result = f(*args, **kwargs)
        logger.info('successfully called {} method'.format(f.__name__))
        return result

    return wrapper
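

# RabbitMQ connections tend to drop during long-running analyses (see the
# TODOs below), so connections are recreated whenever needed.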
def create_rabbit_connection():
    # Note: `credentials` is a module-level global that is only assigned
    # further down, just before the first call to this function.
    return pika.BlockingConnection(pika.ConnectionParameters(
        host=rabbit_host,
        port=rabbit_port,
        virtual_host='/',
        credentials=credentials
    ))
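

# Module-level S3 handles. s3_resource is created here but never used below;
# the write helpers re-create their own short-lived clients instead.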
s3_resource = boto3.resource('s3', aws_access_key_id=s3_access_key,
                             aws_secret_access_key=s3_secret_key,
                             endpoint_url=s3_endpoint)
s3_client = boto3.client('s3', aws_access_key_id=s3_access_key,
                         aws_secret_access_key=s3_secret_key,
                         endpoint_url=s3_endpoint)
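

# S3 I/O helpers: analysis results travel as tab-separated DataFrames,
# plot images as raw bytes.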
def read_data_from_s3(filename: str, s3_bucket_name: str, index_column_first: bool):
    s3_object = s3_client.get_object(Bucket=s3_bucket_name, Key=filename)
    return utils.read_data_from_s3_object(s3_object, filename, index_column_first=index_column_first)


def write_data_in_s3(data: pd.DataFrame, filename: str):
    result_buffer = io.StringIO()
    data.to_csv(result_buffer, index=False, sep='\t')
    result_buffer.seek(0)
    # TODO: Find more elegant solution (connection closes after timeout)
    s3_client = boto3.client('s3', aws_access_key_id=s3_access_key,
                             aws_secret_access_key=s3_secret_key,
                             endpoint_url=s3_endpoint)
    s3_client.put_object(Body=result_buffer.getvalue().encode('utf-8'), Bucket=s3_bucket_name, Key=filename)


def write_image_to_s3(path: str, filename: str):
    # TODO: Find more elegant solution (connection closes after timeout)
    s3_client = boto3.client('s3', aws_access_key_id=s3_access_key,
                             aws_secret_access_key=s3_secret_key,
                             endpoint_url=s3_endpoint)
    # Use a context manager so the file handle is always closed
    with open(path, 'rb') as _io:
        s3_client.put_object(Body=_io, Bucket=s3_bucket_name, Key=filename)
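

# Plot jobs: inputs are staged from S3 into temporary files because the R
# plotting scripts operate on file paths.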
@_track_success
def dot_plot_results(means: str, pvalues: str, rows: str, columns: str, job_id: str):
    with tempfile.TemporaryDirectory() as output_path:
        with tempfile.NamedTemporaryFile(suffix=os.path.splitext(means)[-1]) as means_file, \
                tempfile.NamedTemporaryFile(suffix=os.path.splitext(pvalues)[-1]) as pvalues_file, \
                tempfile.NamedTemporaryFile() as rows_file, \
                tempfile.NamedTemporaryFile() as columns_file:
            _from_s3_to_temp(means, means_file)
            _from_s3_to_temp(pvalues, pvalues_file)
            _from_s3_to_temp(rows, rows_file)
            _from_s3_to_temp(columns, columns_file)
            output_name = 'plot__{}.png'.format(job_id)
            dot_plot(means_path=means_file.name,
                     pvalues_path=pvalues_file.name,
                     output_path=output_path,
                     output_name=output_name,
                     rows=rows_file.name,
                     columns=columns_file.name)
        output_file = os.path.join(output_path, output_name)
        if not os.path.exists(output_file):
            raise PlotException('Could not generate output file for plot of type dot_plot')
        response = {
            'job_id': job_id,
            'files': {
                'plot': output_name,
            },
            'success': True
        }
        write_image_to_s3(output_file, output_name)
        return response


@_track_success
def heatmaps_plot_results(meta: str, pvalues: str, pvalue: float, job_id: str):
    with tempfile.TemporaryDirectory() as output_path:
        with tempfile.NamedTemporaryFile(suffix=os.path.splitext(pvalues)[-1]) as pvalues_file, \
                tempfile.NamedTemporaryFile(suffix=os.path.splitext(meta)[-1]) as meta_file:
            _from_s3_to_temp(pvalues, pvalues_file)
            _from_s3_to_temp(meta, meta_file)
            count_name = 'plot_count__{}.png'.format(job_id)
            count_log_name = 'plot_count_log__{}.png'.format(job_id)
            count_network_name = 'count_network__{}.txt'.format(job_id)
            interactions_count_name = 'interactions_count__{}.txt'.format(job_id)
            heatmaps_plot(meta_file=meta_file.name,
                          pvalues_file=pvalues_file.name,
                          output_path=output_path,
                          count_name=count_name,
                          log_name=count_log_name,
                          count_network_filename=count_network_name,
                          interaction_count_filename=interactions_count_name,
                          pvalue=pvalue)
        output_count_file = os.path.join(output_path, count_name)
        output_count_log_file = os.path.join(output_path, count_log_name)
        output_count_network_file = os.path.join(output_path, count_network_name)
        output_interactions_count_file = os.path.join(output_path, interactions_count_name)
        if not os.path.exists(output_count_file) \
                or not os.path.exists(output_count_log_file) \
                or not os.path.exists(output_count_network_file) \
                or not os.path.exists(output_interactions_count_file):
            raise PlotException('Could not generate output file for plot of type heatmap_plot')
        response = {
            'job_id': job_id,
            'files': {
                'count_plot': count_name,
                'count_log_plot': count_log_name,
                'count_network': count_network_name,
                'interactions_sum': interactions_count_name,
            },
            'success': True
        }
        write_image_to_s3(output_count_file, count_name)
        write_image_to_s3(output_count_log_file, count_log_name)
        write_image_to_s3(output_count_network_file, count_network_name)
        write_image_to_s3(output_interactions_count_file, interactions_count_name)
        return response


def _from_s3_to_temp(key, file):
    data = s3_client.get_object(Bucket=s3_bucket_name, Key=key)
    file.write(data['Body'].read())
    file.seek(0)
    return file
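

# Message dispatchers: each queue message is a JSON document whose fields
# select the plot type or the analysis parameters.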
@_track_success
def process_plot(method, properties, body, logger) -> dict:
    metadata = json.loads(body.decode('utf-8'))
    job_id = metadata['job_id']
    logger.info('New Plot Queued')
    plot_type = metadata.get('type', None)
    if plot_type == 'dot_plot':
        return dot_plot_results(metadata.get('file_means'),
                                metadata.get('file_pvalues'),
                                metadata.get('file_rows', None),
                                metadata.get('file_columns', None),
                                job_id
                                )
    if plot_type == 'heatmaps_plot':
        return heatmaps_plot_results(metadata.get('file_meta'),
                                     metadata.get('file_pvalues'),
                                     metadata.get('pvalue', 0.05),
                                     job_id
                                     )
    return {
        'job_id': job_id,
        'success': False,
        'error': {
            'id': 'UnknownPlotType',
            'message': 'Given plot type does not exist: {}'.format(plot_type)
        }
    }


@_track_success
def process_method(method, properties, body, logger) -> dict:
    metadata = json.loads(body.decode('utf-8'))
    job_id = metadata['job_id']
    logger.info('New Job Queued')
    meta = read_data_from_s3(metadata['file_meta'], s3_bucket_name, index_column_first=False)
    counts = read_data_from_s3(metadata['file_counts'], s3_bucket_name, index_column_first=True)
    subsampler = Subsampler(bool(metadata['log']),
                            int(metadata['num_pc']),
                            int(metadata['num_cells']) if metadata.get('num_cells', False) else None
                            ) if metadata.get('subsampling', False) else None
    # Fall back to the latest local database if the requested version is unavailable
    database_version = metadata.get('database_version', 'latest')
    if database_version not in list_local_versions() + ['latest']:
        database_version = 'latest'
    app = cpdb_app.create_app(verbose=verbose, database_file=find_database_for(database_version))
    if metadata['iterations']:
        response = statistical_analysis(app, meta, counts, job_id, metadata, subsampler)
    else:
        response = non_statistical_analysis(app, meta, counts, job_id, metadata, subsampler)
    return response
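

# Analysis wrappers: the statistical variant runs the permutation-based
# CellPhoneDB method and therefore also produces p-values; the plain method
# analysis reports means only.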
@_track_success
def statistical_analysis(app, meta, counts, job_id, metadata, subsampler):
    pvalues, means, significant_means, deconvoluted = \
        app.method.cpdb_statistical_analysis_launcher(meta,
                                                      counts,
                                                      counts_data=metadata.get('counts_data', 'ensembl'),
                                                      threshold=float(metadata['threshold']),
                                                      iterations=int(metadata['iterations']),
                                                      debug_seed=-1,
                                                      threads=4,
                                                      result_precision=int(metadata['result_precision']),
                                                      pvalue=float(metadata.get('pvalue', 0.05)),
                                                      subsampler=subsampler,
                                                      )
    response = {
        'job_id': job_id,
        'files': {
            'pvalues': 'pvalues_simple_{}.txt'.format(job_id),
            'means': 'means_simple_{}.txt'.format(job_id),
            'significant_means': 'significant_means_simple_{}.txt'.format(job_id),
            'deconvoluted': 'deconvoluted_simple_{}.txt'.format(job_id),
        },
        'success': True
    }
    write_data_in_s3(pvalues, response['files']['pvalues'])
    write_data_in_s3(means, response['files']['means'])
    write_data_in_s3(significant_means, response['files']['significant_means'])
    write_data_in_s3(deconvoluted, response['files']['deconvoluted'])
    return response


@_track_success
def non_statistical_analysis(app, meta, counts, job_id, metadata, subsampler):
    means, significant_means, deconvoluted = \
        app.method.cpdb_method_analysis_launcher(meta,
                                                 counts,
                                                 counts_data=metadata.get('counts_data', 'ensembl'),
                                                 threshold=float(metadata['threshold']),
                                                 result_precision=int(metadata['result_precision']),
                                                 subsampler=subsampler,
                                                 )
    response = {
        'job_id': job_id,
        'files': {
            'means': 'means_simple_{}.txt'.format(job_id),
            'significant_means': 'significant_means_{}.txt'.format(job_id),
            'deconvoluted': 'deconvoluted_simple_{}.txt'.format(job_id),
        },
        'success': True
    }
    write_data_in_s3(means, response['files']['means'])
    write_data_in_s3(significant_means, response['files']['significant_means'])
    write_data_in_s3(deconvoluted, response['files']['deconvoluted'])
    return response
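

# Main consumer loop: poll the job queue, process at most three jobs, then
# exit (presumably the deployment restarts the container afterwards).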
consume_more_jobs = True
credentials = pika.PlainCredentials(rabbit_user, rabbit_password)
connection = create_rabbit_connection()
channel = connection.channel()
channel.basic_qos(prefetch_count=1)
jobs_runned = 0
while jobs_runned < 3 and consume_more_jobs:
    # basic_get returns (None, None, None) when the queue is empty.
    # no_ack is the pika 0.x spelling; pika 1.x renamed it to auto_ack.
    job = channel.basic_get(queue=jobs_queue_name, no_ack=True)
    if all(job):
        job_id = json.loads(job[2].decode('utf-8'))['job_id']
        job_logger = logger_for_job(job_id)
        try:
            if queue_type == 'plot':
                job_response = process_plot(*job, logger=job_logger)
            elif queue_type == 'method':
                job_response = process_method(*job, logger=job_logger)
            else:
                raise Exception('Unknown queue type')
            # TODO: Find more elegant solution
            connection = create_rabbit_connection()
            channel = connection.channel()
            channel.basic_qos(prefetch_count=1)
            channel.basic_publish(exchange='', routing_key=result_queue_name, body=json.dumps(job_response))
            job_logger.info('JOB PROCESSED')
        except (ReadFileException, ParseMetaException, ParseCountsException, ThresholdValueException,
                AllCountsFilteredException, EmptyResultException, PlotException) as e:
            error_response = {
                'job_id': job_id,
                'success': False,
                'error': {
                    'id': str(e),
                    'message': (' {}.'.format(e.description) if hasattr(e, 'description') and e.description else '') +
                               (' {}.'.format(e.hint) if hasattr(e, 'hint') and e.hint else '')
                }
            }
            traceback.print_exc(file=sys.stdout)
            job_logger.error('[-] ERROR DURING PROCESSING JOB')
            if connection.is_closed:
                connection = create_rabbit_connection()
                channel = connection.channel()
                channel.basic_qos(prefetch_count=1)
            channel.basic_publish(exchange='', routing_key=result_queue_name, body=json.dumps(error_response))
            job_logger.error(e)
        except Exception as e:
            error_response = {
                'job_id': job_id,
                'success': False,
                'error': {
                    'id': 'unknown_error',
                    'message': ''
                }
            }
            traceback.print_exc(file=sys.stdout)
            job_logger.error('[-] ERROR DURING PROCESSING JOB')
            if connection.is_closed:
                connection = create_rabbit_connection()
                channel = connection.channel()
                channel.basic_qos(prefetch_count=1)
            channel.basic_publish(exchange='', routing_key=result_queue_name, body=json.dumps(error_response))
            job_logger.error(e)
        jobs_runned += 1
    else:
        rabbit_logger.debug('Empty queue')
        time.sleep(1)