# (removed web-page artifact accidentally captured with the source: "代码拉取完成,页面将自动刷新" — "code pull complete, the page will refresh automatically")
#!/usr/bin/env python
import io
import json
import os
import sys
import tempfile
import time
import traceback
from distutils.util import strtobool
from functools import wraps
from logging import INFO
from typing import Callable
import boto3
import pandas as pd
import pika
from cellphonedb.src.app import cpdb_app
from cellphonedb.src.core.exceptions.AllCountsFilteredException import AllCountsFilteredException
from cellphonedb.src.core.exceptions.EmptyResultException import EmptyResultException
from cellphonedb.src.core.exceptions.ThresholdValueException import ThresholdValueException
from cellphonedb.src.core.utils.subsampler import Subsampler
from cellphonedb.src.database.manager.DatabaseVersionManager import list_local_versions, find_database_for
from cellphonedb.src.exceptions.ParseCountsException import ParseCountsException
from cellphonedb.src.exceptions.ParseMetaException import ParseMetaException
from cellphonedb.src.exceptions.PlotException import PlotException
from cellphonedb.src.exceptions.ReadFileException import ReadFileException
from cellphonedb.src.plotters.r_plotter import dot_plot, heatmaps_plot
from cellphonedb.utils import utils
from rabbit_logger import RabbitAdapter, RabbitLogger
rabbit_logger = RabbitLogger()


def _strtobool(value: str) -> bool:
    """Parse a boolean-ish string, mirroring ``distutils.util.strtobool``.

    Accepts 'y'/'yes'/'t'/'true'/'on'/'1' as True and
    'n'/'no'/'f'/'false'/'off'/'0' as False (case-insensitive).
    ``distutils`` is deprecated and removed in Python 3.12, so the logic is
    inlined here (the top-of-file distutils import can be dropped once no
    other code depends on it).

    Raises:
        ValueError: if *value* is not one of the recognised spellings.
    """
    normalized = value.strip().lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError('invalid truth value {!r}'.format(value))


# All of these settings are mandatory; fail fast with a non-zero exit status
# when any is missing so the container orchestrator can flag the pod.
try:
    s3_access_key = os.environ['S3_ACCESS_KEY']
    s3_secret_key = os.environ['S3_SECRET_KEY']
    s3_bucket_name = os.environ['S3_BUCKET_NAME']
    s3_endpoint = os.environ['S3_ENDPOINT']
    rabbit_host = os.environ['RABBIT_HOST']
    rabbit_port = os.environ['RABBIT_PORT']
    rabbit_user = os.environ['RABBIT_USER']
    rabbit_password = os.environ['RABBIT_PASSWORD']
    jobs_queue_name = os.environ['RABBIT_JOB_QUEUE']
    result_queue_name = os.environ['RABBIT_RESULT_QUEUE']
    queue_type = os.environ['QUEUE_TYPE']
except KeyError as e:
    rabbit_logger.error('ENVIRONMENT VARIABLE {} not defined. Please set it'.format(e))
    # Use sys.exit: the bare `exit` builtin is provided by the `site` module
    # and is not guaranteed to exist in every deployment.
    sys.exit(1)

# VERBOSE is optional and defaults to enabled; it switches on INFO-level logging.
verbose = _strtobool(os.getenv('VERBOSE', 'true'))
if verbose:
    rabbit_logger.setLevel(INFO)
def logger_for_job(job_id):
    """Return a logger adapter that tags every record with *job_id*."""
    adapter = RabbitAdapter.logger_for(rabbit_logger, job_id)
    return adapter
def _track_success(f) -> Callable:
    """Decorator that logs entry to, and successful return from, *f*.

    The logger is taken from the ``logger`` keyword argument when the caller
    supplies one; otherwise the module-level rabbit logger is used.
    """
    @wraps(f)
    def wrapper(*args, **kwargs):
        if 'logger' in kwargs:
            logger = kwargs['logger']
        else:
            logger = rabbit_logger
        logger.info('calling {} method'.format(f.__name__))
        result = f(*args, **kwargs)
        logger.info('successfully called {} method'.format(f.__name__))
        return result

    return wrapper
def create_rabbit_connection():
    """Open a new blocking connection to the RabbitMQ broker.

    Relies on the module-level ``rabbit_host``/``rabbit_port`` settings and
    the ``credentials`` object created at startup.
    """
    parameters = pika.ConnectionParameters(
        host=rabbit_host,
        # Environment variables are always strings; pika validates the port
        # as an integer, so convert explicitly.
        port=int(rabbit_port),
        virtual_host='/',
        credentials=credentials,
    )
    return pika.BlockingConnection(parameters)
# Shared S3 handles created once at startup and reused by the read helpers.
# NOTE(review): ``s3_resource`` does not appear to be used anywhere in this
# module (all helpers go through ``s3_client``) — confirm before removing.
s3_resource = boto3.resource('s3', aws_access_key_id=s3_access_key,
                             aws_secret_access_key=s3_secret_key,
                             endpoint_url=s3_endpoint)
s3_client = boto3.client('s3', aws_access_key_id=s3_access_key,
                         aws_secret_access_key=s3_secret_key,
                         endpoint_url=s3_endpoint)
def read_data_from_s3(filename: str, s3_bucket_name: str, index_column_first: bool):
    """Fetch *filename* from the given S3 bucket and parse it via the shared utils reader."""
    raw_object = s3_client.get_object(Bucket=s3_bucket_name, Key=filename)
    parsed = utils.read_data_from_s3_object(raw_object, filename,
                                            index_column_first=index_column_first)
    return parsed
def write_data_in_s3(data: pd.DataFrame, filename: str):
    """Serialise *data* as a tab-separated table and upload it to S3 under key *filename*."""
    buffer = io.StringIO()
    data.to_csv(buffer, index=False, sep='\t')
    buffer.seek(0)
    # TODO: Find more elegant solution (connexion closes after timeout)
    fresh_client = boto3.client('s3', aws_access_key_id=s3_access_key,
                                aws_secret_access_key=s3_secret_key,
                                endpoint_url=s3_endpoint)
    fresh_client.put_object(Body=buffer.getvalue().encode('utf-8'),
                            Bucket=s3_bucket_name, Key=filename)
def write_image_to_s3(path: str, filename: str):
    """Upload the local file at *path* to the results bucket under key *filename*.

    A fresh client is created per call because the long-lived module-level
    client's connection can close after a timeout while a job is running.
    """
    # TODO: Find more elegant solution (connexion closes after timeout)
    fresh_client = boto3.client('s3', aws_access_key_id=s3_access_key,
                                aws_secret_access_key=s3_secret_key,
                                endpoint_url=s3_endpoint)
    # Context manager ensures the file descriptor is closed even if the
    # upload fails (the original opened the file and never closed it).
    with open(path, 'rb') as image_file:
        fresh_client.put_object(Body=image_file, Bucket=s3_bucket_name, Key=filename)
@_track_success
def dot_plot_results(means: str, pvalues: str, rows: str, columns: str, job_id: str):
    """Render a dot plot from S3-hosted means/pvalues files and upload the PNG.

    *rows*/*columns* are S3 keys restricting which axes are plotted.
    Returns the job response dict describing the generated file.

    Raises:
        PlotException: if the plotting routine produced no output file.
    """
    with tempfile.TemporaryDirectory() as output_path, \
            tempfile.NamedTemporaryFile(suffix=os.path.splitext(means)[-1]) as means_file, \
            tempfile.NamedTemporaryFile(suffix=os.path.splitext(pvalues)[-1]) as pvalues_file, \
            tempfile.NamedTemporaryFile() as rows_file, \
            tempfile.NamedTemporaryFile() as columns_file:
        # Stage every input locally: the R plotter works on file paths.
        for key, handle in ((means, means_file), (pvalues, pvalues_file),
                            (rows, rows_file), (columns, columns_file)):
            _from_s3_to_temp(key, handle)
        output_name = 'plot__{}.format'.replace('format', '').join([])  # placeholder removed below
        output_name = 'plot__{}.png'.format(job_id)
        dot_plot(means_path=means_file.name,
                 pvalues_path=pvalues_file.name,
                 output_path=output_path,
                 output_name=output_name,
                 rows=rows_file.name,
                 columns=columns_file.name)
        output_file = os.path.join(output_path, output_name)
        if not os.path.exists(output_file):
            raise PlotException('Could not generate output file for plot of type dot_plot')
        write_image_to_s3(output_file, output_name)
        return {
            'job_id': job_id,
            'files': {
                'plot': output_name,
            },
            'success': True
        }
@_track_success
def heatmaps_plot_results(meta: str, pvalues: str, pvalue: float, job_id: str):
    """Render count/log-count heatmaps plus the network and count tables, then upload all four.

    *meta* and *pvalues* are S3 keys; *pvalue* is the significance cutoff
    forwarded to the R plotter. Returns the job response dict.

    Raises:
        PlotException: if any expected output file was not produced.
    """
    with tempfile.TemporaryDirectory() as output_path, \
            tempfile.NamedTemporaryFile(suffix=os.path.splitext(pvalues)[-1]) as pvalues_file, \
            tempfile.NamedTemporaryFile(suffix=os.path.splitext(meta)[-1]) as meta_file:
        _from_s3_to_temp(pvalues, pvalues_file)
        _from_s3_to_temp(meta, meta_file)
        # Response key -> generated file name (insertion order matches the
        # original upload order).
        outputs = {
            'count_plot': 'plot_count__{}.png'.format(job_id),
            'count_log_plot': 'plot_count_log__{}.png'.format(job_id),
            'count_network': 'count_network__{}.txt'.format(job_id),
            'interactions_sum': 'interactions_count__{}.txt'.format(job_id),
        }
        heatmaps_plot(meta_file=meta_file.name,
                      pvalues_file=pvalues_file.name,
                      output_path=output_path,
                      count_name=outputs['count_plot'],
                      log_name=outputs['count_log_plot'],
                      count_network_filename=outputs['count_network'],
                      interaction_count_filename=outputs['interactions_sum'],
                      pvalue=pvalue)
        generated = {name: os.path.join(output_path, name) for name in outputs.values()}
        if not all(os.path.exists(path) for path in generated.values()):
            raise PlotException('Could not generate output file for plot of type heatmap_plot')
        for name, path in generated.items():
            write_image_to_s3(path, name)
        return {
            'job_id': job_id,
            'files': outputs,
            'success': True
        }
def _from_s3_to_temp(key, file):
    """Download the S3 object *key* into the open file handle *file* and rewind it."""
    payload = s3_client.get_object(Bucket=s3_bucket_name, Key=key)['Body'].read()
    file.write(payload)
    file.seek(0)
    return file
@_track_success
def process_plot(method, properties, body, logger) -> dict:
    """Dispatch a queued plot job to the matching plotter.

    *body* is the raw JSON message; its ``type`` field selects the plot.
    Returns the plotter's response dict, or an UnknownPlotType error response.
    """
    metadata = json.loads(body.decode('utf-8'))
    job_id = metadata['job_id']
    logger.info('New Plot Queued')
    plot_type = metadata.get('type', None)
    if plot_type == 'dot_plot':
        return dot_plot_results(metadata.get('file_means'),
                                metadata.get('file_pvalues'),
                                metadata.get('file_rows', None),
                                metadata.get('file_columns', None),
                                job_id)
    elif plot_type == 'heatmaps_plot':
        return heatmaps_plot_results(metadata.get('file_meta'),
                                     metadata.get('file_pvalues'),
                                     metadata.get('pvalue', 0.05),
                                     job_id)
    error = {
        'id': 'UnknownPlotType',
        'message': 'Given plot type does not exist: {}'.format(plot_type),
    }
    return {'job_id': job_id, 'success': False, 'error': error}
@_track_success
def process_method(method, properties, body, logger) -> dict:
    """Run a queued CellphoneDB analysis job end to end.

    Downloads the meta/counts inputs from S3, builds the optional subsampler,
    selects the database version, then runs either the statistical or the
    simple analysis depending on the job's ``iterations`` setting.
    """
    metadata = json.loads(body.decode('utf-8'))
    job_id = metadata['job_id']
    logger.info('New Job Queued')
    meta = read_data_from_s3(metadata['file_meta'], s3_bucket_name, index_column_first=False)
    counts = read_data_from_s3(metadata['file_counts'], s3_bucket_name, index_column_first=True)
    if metadata.get('subsampling', False):
        num_cells = int(metadata['num_cells']) if metadata.get('num_cells', False) else None
        subsampler = Subsampler(bool(metadata['log']), int(metadata['num_pc']), num_cells)
    else:
        subsampler = None
    # Fall back to the latest bundled database when the requested version is
    # not available locally.
    database_version = metadata.get('database_version', 'latest')
    if database_version not in list_local_versions() + ['latest']:
        database_version = 'latest'
    app = cpdb_app.create_app(verbose=verbose, database_file=find_database_for(database_version))
    analysis = statistical_analysis if metadata['iterations'] else non_statistical_analysis
    return analysis(app, meta, counts, job_id, metadata, subsampler)
@_track_success
def statistical_analysis(app, meta, counts, job_id, metadata, subsampler):
    """Run the permutation-based (statistical) analysis and upload the four result tables.

    Returns the job response dict mapping logical result names to the
    uploaded S3 file names.
    """
    pvalues, means, significant_means, deconvoluted = app.method.cpdb_statistical_analysis_launcher(
        meta,
        counts,
        counts_data=metadata.get('counts_data', 'ensembl'),
        threshold=float(metadata['threshold']),
        iterations=int(metadata['iterations']),
        debug_seed=-1,
        threads=4,
        result_precision=int(metadata['result_precision']),
        pvalue=float(metadata.get('pvalue', 0.05)),
        subsampler=subsampler,
    )
    files = {
        'pvalues': 'pvalues_simple_{}.txt'.format(job_id),
        'means': 'means_simple_{}.txt'.format(job_id),
        'significant_means': 'significant_means_simple_{}.txt'.format(job_id),
        'deconvoluted': 'deconvoluted_simple_{}.txt'.format(job_id),
    }
    for frame, key in ((pvalues, 'pvalues'), (means, 'means'),
                       (significant_means, 'significant_means'),
                       (deconvoluted, 'deconvoluted')):
        write_data_in_s3(frame, files[key])
    return {'job_id': job_id, 'files': files, 'success': True}
@_track_success
def non_statistical_analysis(app, meta, counts, job_id, metadata, subsampler):
    """Run the simple (non-permutation) analysis and upload the three result tables.

    Returns the job response dict mapping logical result names to the
    uploaded S3 file names.
    """
    means, significant_means, deconvoluted = app.method.cpdb_method_analysis_launcher(
        meta,
        counts,
        counts_data=metadata.get('counts_data', 'ensembl'),
        threshold=float(metadata['threshold']),
        result_precision=int(metadata['result_precision']),
        subsampler=subsampler,
    )
    # NOTE(review): unlike the other entries (and unlike statistical_analysis),
    # 'significant_means' lacks the '_simple' infix — confirm whether
    # downstream consumers rely on this exact name before unifying.
    files = {
        'means': 'means_simple_{}.txt'.format(job_id),
        'significant_means': 'significant_means_{}.txt'.format(job_id),
        'deconvoluted': 'deconvoluted_simple_{}.txt'.format(job_id),
    }
    for frame, key in ((means, 'means'),
                       (significant_means, 'significant_means'),
                       (deconvoluted, 'deconvoluted')):
        write_data_in_s3(frame, files[key])
    return {'job_id': job_id, 'files': files, 'success': True}
consume_more_jobs = True
credentials = pika.PlainCredentials(rabbit_user, rabbit_password)
connection = create_rabbit_connection()
channel = connection.channel()
channel.basic_qos(prefetch_count=1)


def _refresh_channel():
    """Re-open the RabbitMQ connection and channel.

    TODO: Find more elegant solution — the broker connection closes after a
    timeout while long-running jobs execute, so we reconnect before publishing.
    """
    global connection, channel
    connection = create_rabbit_connection()
    channel = connection.channel()
    channel.basic_qos(prefetch_count=1)


def _publish_result(response: dict):
    """Publish *response* on the results queue, reconnecting first if the connection closed."""
    if connection.is_closed:
        _refresh_channel()
    channel.basic_publish(exchange='', routing_key=result_queue_name, body=json.dumps(response))


def _report_failure(job_logger, error, response: dict):
    """Log the traceback and push an error *response* back to the results queue."""
    # Print the traceback directly; the original wrapped this in print(),
    # which printed a spurious 'None' (print_exc returns None).
    traceback.print_exc(file=sys.stdout)
    job_logger.error('[-] ERROR DURING PROCESSING JOB')
    _publish_result(response)
    job_logger.error(error)


jobs_runned = 0
# Process at most three jobs per run; the worker exits afterwards and is
# recycled by the orchestrator.
while jobs_runned < 3 and consume_more_jobs:
    job = channel.basic_get(queue=jobs_queue_name, no_ack=True)
    if all(job):
        job_id = json.loads(job[2].decode('utf-8'))['job_id']
        job_logger = logger_for_job(job_id)
        try:
            if queue_type == 'plot':
                job_response = process_plot(*job, logger=job_logger)
            elif queue_type == 'method':
                job_response = process_method(*job, logger=job_logger)
            else:
                raise Exception('Unknown queue type')
            # Always reconnect on the success path: the job may have run long
            # enough for the original connection to time out.
            _refresh_channel()
            channel.basic_publish(exchange='', routing_key=result_queue_name,
                                  body=json.dumps(job_response))
            job_logger.info('JOB PROCESSED')
        except (ReadFileException, ParseMetaException, ParseCountsException, ThresholdValueException,
                AllCountsFilteredException, EmptyResultException, PlotException) as e:
            # Known, user-facing failures: forward the exception's description and hint.
            _report_failure(job_logger, e, {
                'job_id': job_id,
                'success': False,
                'error': {
                    'id': str(e),
                    'message': (' {}.'.format(e.description) if hasattr(e, 'description') and e.description else '') +
                               (' {}.'.format(e.hint) if hasattr(e, 'hint') and e.hint else '')
                }
            })
        except Exception as e:
            # Anything unexpected is reported as an opaque unknown error.
            _report_failure(job_logger, e, {
                'job_id': job_id,
                'success': False,
                'error': {
                    'id': 'unknown_error',
                    'message': ''
                }
            })
        jobs_runned += 1
    else:
        rabbit_logger.debug('Empty queue')
        time.sleep(1)
# (removed hosting-site moderation boilerplate accidentally captured with the source)