diff --git a/contrib/FireDetection/python/README.md b/contrib/FireDetection/python/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3832a76d529d43e34843c14899451f01354854d6
--- /dev/null
+++ b/contrib/FireDetection/python/README.md
@@ -0,0 +1,82 @@
+# 基于mxBase的高速公路车辆火灾识别
+
+## 1 介绍
+
+高速公路车辆火灾识别基于 MindX Vision 开发,在 Atlas 300V、Atlas 300V Pro 上进行目标检测。项目主要流程为:通过av模块打开本地视频文件、模拟视频流,然后进行视频解码,解码结果经过模型推理进行火焰和烟雾检测,如果检测到烟雾和火灾则在日志中进行告警。解码后的视频图像会再次编码保存至指定位置。
+
+### 1.1 支持的硬件形态
+支持Atlas 300V和Atlas 300V Pro
+
+### 1.2 支持的版本
+
+   | MxVision版本 | CANN版本 | Driver/Firmware版本 |
+   | --------- | ------------------ | -------------- |
+   | 6.0.RC2 | 8.0.RC2 | 24.1.RC2 |
+
+### 1.3 软件方案介绍
+
+基于MindX Vision的mxBase架构的高速公路车辆火灾识别业务流程为:经av库打开本地视频文件、模拟视频流——>将视频解码成图片——>将图像缩放至满足检测模型要求的大小——>将缩放后的图像输入模型进行车辆火灾识别,如果检测到火焰或者烟雾则在日志层面进行告警——>将解码后的视频图像编码保存至指定文件路径。
+
+### 1.4 代码目录结构与说明
+
+本项目目录如下图所示:
+
+```
+├── aipp_yolov5.cfg // 模型转换使用的AIPP配置
+├── frame_analyzer.py // 视频帧分析
+├── infer_config.json // 服务配置
+├── utils.py
+├── main.py
+└── README.md
+```
+## 2 Python环境依赖
+本项目除了依赖昇腾Driver、Firmware、CANN和MxVision及其要求的配套软件外,还需额外依赖以下python软件:
+
+| 软件名称 | 版本 |
+| -------- | ------ |
+| av | 10.0.0 |
+| numpy | 1.23.5 |
+
+## 3 模型下载和转换
+### 3.1 下载模型相关文件
+- **步骤1** 根据[链接](https://mindx.sdk.obs.cn-north-4.myhuaweicloud.com/mindxsdk-referenceapps%20/contrib/FireDetection/models.zip)下载并解压得到firedetection.onnx文件。
+
+### 3.2 转换模型格式
+- **步骤1** 设置环境变量
+
+    . /usr/local/Ascend/ascend-toolkit/set_env.sh # Ascend-cann-toolkit开发套件包默认安装路径,根据实际安装路径修改
+- **步骤2** 将onnx格式模型转换为om格式模型(--soc_version的参数需根据实际NPU型号设置,Atlas 300V和Atlas 300V Pro设备下该参数为Ascend310P3)
+
+    atc --model=./firedetection.onnx --framework=5 --output=./firedetection --input_format=NCHW --input_shape="images:1,3,640,640" --out_nodes="Transpose_217:0;Transpose_233:0;Transpose_249:0" --enable_small_channel=1 --insert_op_conf=./aipp_yolov5.cfg --soc_version=Ascend310P3 --log=info
+
+## 4 启动和停止高速公路火灾识别服务
+### 4.1 启动高速公路火灾识别服务
+
+- **步骤1** 设置环境变量
+
+    . /usr/local/Ascend/ascend-toolkit/set_env.sh # Ascend-cann-toolkit开发套件包默认安装路径,根据实际安装路径修改
+    . ${MX_SDK_HOME}/mxVision/set_env.sh # ${MX_SDK_HOME}替换为用户的SDK安装路径
+
+- **步骤2** 设置高速公路车辆火灾识别服务配置(修改infer_config.json文件),支持的配置项如下所示:
+
+
+| 配置项字段 | 配置项含义 |
+|:-----------------:|-----------------|
+| video_path | 用于火灾识别的视频文件路径 |
+| model_path | om模型的路径 |
+| device_id | 运行服务时使用的NPU设备编号 |
+| skip_frame_number | 指定两次推理的帧间隔数量 |
+| video_saved_path | 指定编码后视频保存的文件路径 |
+| width | 用于火灾识别的视频文件的宽度 |
+| height | 用于火灾识别的视频文件的高度 |
+
+
+*device_id取值范围为[0, NPU设备个数-1],`npu-smi info` 命令可以查看NPU设备个数;skip_frame_number建议根据实际业务需求设置,推荐设置为3;width和height的取值范围为[128, 4096];video_path所指定的视频文件需为H264编码;video_saved_path所指定的文件每次服务启动时会被覆盖重写。
+
+- **步骤3** 启动火灾检测服务。火灾检测结果在warning级别日志中体现;编码视频文件保存在配置文件指定的路径下。
+
+    python3 main.py
+### 4.2 停止高速公路火灾识别服务
+- 停止服务有如下两种方式:
+
+    1.视频文件分析完毕后可自动停止服务。 2.命令行输入Ctrl+C组合键可手动停止服务。
\ No newline at end of file
diff --git a/contrib/FireDetection/python/aipp_yolov5.cfg b/contrib/FireDetection/python/aipp_yolov5.cfg
new file mode 100644
index 0000000000000000000000000000000000000000..4990f0d9c63f3a97efa422afb925d4eff484c06f
--- /dev/null
+++ b/contrib/FireDetection/python/aipp_yolov5.cfg
@@ -0,0 +1,26 @@
+aipp_op {
+    aipp_mode : static
+    related_input_rank : 0
+    input_format : YUV420SP_U8
+
+    src_image_size_w : 640
+    src_image_size_h : 640
+    crop : false
+    csc_switch : true
+    rbuv_swap_switch : false
+    matrix_r0c0 : 256
+    matrix_r0c1 : 0
+    matrix_r0c2 : 359
+    matrix_r1c0 : 256
+    matrix_r1c1 : -88
+    matrix_r1c2 : -183
+    matrix_r2c0 : 256
+    matrix_r2c1 : 454
+    matrix_r2c2 : 0
+    input_bias_0 : 0
+    input_bias_1 : 128
+    input_bias_2 : 128
+    var_reci_chn_0 : 0.0039216
+    var_reci_chn_1 : 0.0039216
+    var_reci_chn_2 : 0.0039216
+}
diff --git a/contrib/FireDetection/python/frame_analyzer.py b/contrib/FireDetection/python/frame_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e780c9b63763b25289d53fce30751f61e113b943
--- /dev/null
+++ b/contrib/FireDetection/python/frame_analyzer.py
@@ -0,0 +1,152 @@
+import os
+import numpy as np
+from mindx.sdk import base
+from mindx.sdk.base import Tensor, Model, Size, Rect, log, ImageProcessor, post, Point
+from utils import file_base_check, logger
+
+MODEL_INPUT_HEIGHT = 640
+MODEL_INPUT_WIDTH = 640
+MODEL_SHAPE = Size(MODEL_INPUT_HEIGHT, MODEL_INPUT_WIDTH)
+ANCHORS_SIZE = [[[10, 13], [16, 30], [33, 23]], [[30, 61], [62, 45], [59, 119]], [[116, 90], [156, 198], [373, 326]]]
+NMS_THRESHOLD = 0.6
+INDEX_TO_CLASS = {0: "Fire", 1: "Smoke"}
+
+
+def sigmoid(x):
+    """Return the element-wise logistic sigmoid of x."""
+    return 1 / (1 + np.exp(-x))
+
+
+def nms(dets, thresh):
+    """Greedy NMS over rows of (xmin, ymin, xmax, ymax, score, ...); return the kept row indices."""
+    x1 = dets[:, 0]  # xmin
+    y1 = dets[:, 1]  # ymin
+    x2 = dets[:, 2]  # xmax
+    y2 = dets[:, 3]  # ymax
+    scores = dets[:, 4]  # confidence
+
+    areas = (x2 - x1 + 1) * (y2 - y1 + 1)  # area of every bounding box
+    order = scores.argsort()[::-1]  # indices sorted by descending confidence
+
+    keep = []  # indices of the bounding boxes that survive suppression
+    while order.size > 0:
+        i = order[0]  # index of the most confident remaining box
+        keep.append(i)  # the most confident remaining box is always kept
+
+        # intersection rectangle between the current box and every remaining box
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        # area of the intersection
+        w = np.maximum(0.0, xx2 - xx1 + 1)
+        h = np.maximum(0.0, yy2 - yy1 + 1)
+        inter = w * h
+
+        # IoU = intersection area / (area of current box + area of other box - intersection area)
+        ovr = inter / (areas[i] + areas[order[1:]] - inter)
+
+        # keep only the boxes whose overlap with the current box does not exceed the threshold
+        inds = np.where(ovr <= thresh)[0]
+        order = order[inds + 1]
+    return keep
+
+
+class FrameAnalyzeModel:
+    """Wrap the fire/smoke detection om model together with its pre- and post-processing."""
+    def __init__(self, model_path, device_id):
+        """Validate the model file, then load it and create an image processor on device_id."""
+        file_base_check(os.path.realpath(model_path))
+        self.model = Model(os.path.realpath(model_path), device_id)
+        self.image_processor = ImageProcessor(device_id)
+
+    @staticmethod
+    def __decode_output(output_tensors):
+        """Decode the raw model outputs into an array of [x0, y0, x1, y1, score, class_id] rows."""
+        output_np_tensors = []
+        for tensor in output_tensors:
+            tensor.to_host()
+            output_np_tensors.append(np.array(tensor))
+        bounding_box_array = []
+        for layer_idx, tensor in enumerate(output_np_tensors):
+            batch, anchor_num, height, width, box_para = tensor.shape
+            for height_idx in range(height):
+                for width_idx in range(width):
+                    for anchor_idx in range(anchor_num):
+                        # Filter unimportant anchor and determine the class of anchor according to the given threshold
+                        objectness = sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 4])
+                        if objectness < 0.1:
+                            continue
+
+                        class_score1 = sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 5]) * objectness
+                        class_score2 = sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 6]) * objectness
+                        if class_score1 < 0.4 and class_score2 < 0.4:
+                            continue
+
+                        temp_score = -1
+                        temp_class_id = -1
+                        if class_score1 < class_score2:
+                            temp_score = class_score2
+                            temp_class_id = 1
+                        else:
+                            temp_score = class_score1
+                            temp_class_id = 0
+
+                        # Convert relative box info into absolute box info according to prior anchors
+                        temp_x = width_idx + sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 0]) * 2 - 0.5
+                        temp_y = height_idx + sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 1]) * 2 - 0.5
+                        temp_width = sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 2]) * \
+                                     sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 2]) * 4 * \
+                                     ANCHORS_SIZE[layer_idx][anchor_idx][0]
+                        temp_height = sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 3]) * \
+                                      sigmoid(tensor[0, anchor_idx, height_idx, width_idx, 3]) * 4 * \
+                                      ANCHORS_SIZE[layer_idx][anchor_idx][1]
+
+                        # Convert (x, y, w, h) in grid units into pixel (x0, y0, x1, y1); y is scaled by the grid height
+                        x0 = max(temp_x / width * MODEL_INPUT_WIDTH - temp_width / 2, 0)
+                        y0 = max(temp_y / height * MODEL_INPUT_HEIGHT - temp_height / 2, 0)
+                        x1 = min(temp_x / width * MODEL_INPUT_WIDTH + temp_width / 2, MODEL_INPUT_WIDTH)
+                        y1 = min(temp_y / height * MODEL_INPUT_HEIGHT + temp_height / 2, MODEL_INPUT_HEIGHT)
+                        bounding_box_array.append([x0, y0, x1, y1, temp_score, temp_class_id])
+        return np.array(bounding_box_array)
+
+    def infer(self, image):
+        """Resize image to the model input if needed, run the model and return NMS-filtered boxes in original scale."""
+        height_ratio = image.original_height / MODEL_INPUT_HEIGHT
+        width_ratio = image.original_width / MODEL_INPUT_WIDTH
+        if image.height != MODEL_INPUT_HEIGHT or image.width != MODEL_INPUT_WIDTH:
+            image = self.image_processor.resize(image, MODEL_SHAPE, base.huaweiu_high_order_filter)
+        # model inference
+        image_tensor = [image.to_tensor()]
+        output_tensors = self.model.infer(image_tensor)
+        # decode output results
+        bounding_box_array = self.__decode_output(output_tensors)
+        # conduct non max suppression
+        if bounding_box_array.size != 0:
+            keep_idx = nms(bounding_box_array, NMS_THRESHOLD)
+            # correct bounding box bias due to resize operation
+            bounding_box_array = bounding_box_array[keep_idx, :]
+            bounding_box_array[:, [0, 2]] *= width_ratio
+            bounding_box_array[:, [1, 3]] *= height_ratio
+        return bounding_box_array
+
+
+class FrameAnalyzer:
+    """Run FrameAnalyzeModel on frames and report its detections through the logger."""
+    def __init__(self, model_path, device_id):
+        self.frame_analyze_model = FrameAnalyzeModel(model_path, device_id)
+
+    @staticmethod
+    def alarm(analysis_info, frame_id):
+        """Log one warning per row of analysis_info ([x0, y0, x1, y1, score, class_id]) for frame frame_id."""
+        for bounding_box in analysis_info:
+            left_top_point, right_bottom_point = (int(bounding_box[0]), int(bounding_box[1])),\
+                (int(bounding_box[2]), int(bounding_box[3]))
+            logger.warning("Frame {} detect {}! Confidence: {:.2f}, x0: {:.2f}, y0: {:.2f}, x1: {:.2f}, y1: {:.2f}"
+                           .format(frame_id, INDEX_TO_CLASS[bounding_box[5]], bounding_box[4], left_top_point[0],
+                                   left_top_point[1], right_bottom_point[0], right_bottom_point[1]))
+
+    def analyze(self, image):
+        """Return the detections FrameAnalyzeModel.infer produces for image."""
+        return self.frame_analyze_model.infer(image)
diff --git a/contrib/FireDetection/python/infer_config.json b/contrib/FireDetection/python/infer_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..160aea4d4b892b2aaf629646687afa1a58acf107
--- /dev/null
+++ b/contrib/FireDetection/python/infer_config.json
@@ -0,0 +1,9 @@
+{
+  "video_path": "./fireDetection.264",
+  "model_path": "./firedetection.om",
+  "device_id": 0,
+  "skip_frame_number": 3,
+  "video_saved_path": "./output.h264",
+  "width": 1920,
+  "height": 1080
+}
diff --git a/contrib/FireDetection/python/main.py b/contrib/FireDetection/python/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..58aff115bbd939ae1d6f1146c6ad99a0675a5c0a
--- /dev/null
+++ b/contrib/FireDetection/python/main.py
@@ -0,0 +1,153 @@
+import threading
+import signal
+import time
+import av
+from mindx.sdk import base
+from mindx.sdk.base import VideoDecoder, VideoDecodeConfig,\
+    VdecCallBacker, VideoEncoder, VideoEncodeConfig, VencCallBacker
+from frame_analyzer import FrameAnalyzer
+from utils import infer_config, logger
+
+decoded_data_queue = []
+analyzed_data_queue = []
+decode_finished_flag = False
+SIGNAL_RECEIVED = False
+
+
+class Frame:
+    """Pair a decoded image with the index of the frame it came from."""
+    def __init__(self, image, frame_id):
+        self.image = image
+        self.frame_id = frame_id
+
+
+def stop_handler(signum, frame):
+    """Signal handler: set the shared stop flag polled by every worker thread."""
+    global SIGNAL_RECEIVED
+    SIGNAL_RECEIVED = True
+
+
+def vdec_callback_func(decoded_image, channel_id, frame_id):
+    """Decoder callback: queue each decoded image for the encode thread."""
+    logger.debug('Video decoder output decoded image (channelId:{}, frameId:{}, image.width:{},'
+                 ' image.height:{}, image.format:{})'.format(channel_id, frame_id, decoded_image.width,
+                                                             decoded_image.height, decoded_image.format))
+    # store the decoded Image object in the shared list
+    decoded_data_queue.append(decoded_image)
+
+
+def vdec_thread_func(vdec_config, vdec_callbacker, device_id, rtsp):
+    """Demux the video opened from rtsp and feed its packets to a VideoDecoder until it ends or a stop is requested."""
+    global decode_finished_flag
+    global SIGNAL_RECEIVED
+    try:
+        with av.open(rtsp) as container:
+            count = 0
+            # initialise the VideoDecoder
+            video_decoder = VideoDecoder(vdec_config, vdec_callbacker, device_id, 0)
+            # pull packets one by one and send them to the decoder
+            for packet in container.demux():
+                if SIGNAL_RECEIVED:
+                    break
+                if packet.size == 0:
+                    logger.info("Finish to pull rtsp stream.")
+                    break
+                logger.debug("send packet:{} ".format(count))
+                video_decoder.decode(packet, count)
+                time.sleep(0.02)
+                count += 1
+            logger.info("There are {} frames in total.".format(count))
+    finally:
+        # Always raise the stop flag, even when demux ends without an empty packet or when
+        # opening/decoding fails; otherwise the encode and analyze threads would loop forever.
+        decode_finished_flag = True
+        SIGNAL_RECEIVED = True
+
+
+# video encoder callback
+def venc_callback_func(output, output_datasize, channel_id, frame_id):
+    """Encoder callback: append each encoded chunk to the configured output file."""
+    logger.debug('Video encoder output encoded_stream. (type:{}, outDataSize:{}, channelId:{}, frameId:{})'
+                 .format(type(output), output_datasize, channel_id, frame_id))
+    with open(infer_config["video_saved_path"], 'ab') as file:
+        file.write(output)
+
+
+def venc_thread_func(venc_config, venc_callbacker, device_id):
+    """Encode every decoded frame and hand every skip_frame_number-th frame to the analyze thread."""
+    video_encoder = VideoEncoder(venc_config, venc_callbacker, device_id)
+    i = 0
+    global SIGNAL_RECEIVED
+    while not (SIGNAL_RECEIVED and not decoded_data_queue):
+        if not decoded_data_queue:
+            time.sleep(0.001)  # yield the CPU instead of spinning while the queue is empty
+            continue
+        frame_image = decoded_data_queue.pop(0)
+        if i % infer_config["skip_frame_number"] == 0:
+            analyzed_data_queue.append(Frame(frame_image, i))
+        video_encoder.encode(frame_image, i)
+        time.sleep(0.02)
+        i += 1
+    logger.info("Venc thread ended.")
+
+
+def analyze_thread_func(model_path, device_id):
+    """Analyze queued frames and log an alarm for every frame that has detections."""
+    frame_analyzer = FrameAnalyzer(model_path, device_id)
+    global SIGNAL_RECEIVED
+    while not (SIGNAL_RECEIVED and not analyzed_data_queue):
+        if not analyzed_data_queue:
+            time.sleep(0.001)  # yield the CPU instead of spinning while the queue is empty
+            continue
+        frame = analyzed_data_queue.pop(0)
+        results = frame_analyzer.analyze(frame.image)
+        if results.size != 0:
+            frame_analyzer.alarm(results, frame.frame_id)
+    logger.info("Analyze thread ended.")
+
+
+signal.signal(signal.SIGINT, stop_handler)
+if __name__ == '__main__':
+    base.mx_init()
+    vdec_callbacker_instance = VdecCallBacker()
+    vdec_callbacker_instance.registerVdecCallBack(vdec_callback_func)
+    # initialise VideoDecodeConfig and set its parameters
+    vdec_conf = VideoDecodeConfig()
+    vdec_conf.inputVideoFormat = base.h264_main_level
+    vdec_conf.outputImageFormat = base.nv12
+    vdec_conf.width = infer_config["width"]
+    vdec_conf.height = infer_config["height"]
+    # initialise VencCallBacker and register the callback
+    venc_callbacker_instance = VencCallBacker()
+    venc_callbacker_instance.registerVencCallBack(venc_callback_func)
+    # initialise VideoEncodeConfig
+    venc_conf = VideoEncodeConfig()
+    venc_conf.keyFrameInterval = 50
+    venc_conf.srcRate = 30
+    venc_conf.maxBitRate = 6000
+    venc_conf.ipProp = 30
+
+    # create the worker threads and pass their arguments
+    vdec = threading.Thread(target=vdec_thread_func, kwargs={'vdec_config': vdec_conf,
+                                                             'vdec_callbacker': vdec_callbacker_instance,
+                                                             "device_id": infer_config["device_id"],
+                                                             "rtsp": infer_config["video_path"]})
+
+    venc = threading.Thread(target=venc_thread_func, kwargs={'venc_config': venc_conf,
+                                                             'venc_callbacker': venc_callbacker_instance,
+                                                             "device_id": infer_config["device_id"]})
+
+    analyze = threading.Thread(target=analyze_thread_func, kwargs={"model_path": infer_config["model_path"],
+                                                                   "device_id": infer_config["device_id"]})
+
+    # start the threads
+    vdec.start()
+    venc.start()
+    analyze.start()
+
+    # wait for all of them to finish
+    vdec.join()
+    venc.join()
+    analyze.join()
+
+    logger.info("Fire detection task ended successfully.")
diff --git a/contrib/FireDetection/python/utils.py b/contrib/FireDetection/python/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..38774c6bfd105c09d4b0ab17cf47619e9ec8431b
--- /dev/null
+++ b/contrib/FireDetection/python/utils.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
+Description: Common function for utilization.
+Author: MindX SDK
+Create: 2024
+History: NA
+"""
+import json
+import os
+import logging
+
+
+def file_base_check(path: str) -> None:
+    """Raise Exception unless path is an existing, readable regular file that is not a symbolic link."""
+    file_name = os.path.basename(path) if path else ""
+    if not path or not os.path.isfile(path):
+        raise Exception('The file:{} does not exist!'.format(file_name))
+    if os.path.islink(path):
+        raise Exception('The file:{} is link. invalid file!'.format(file_name))
+    if not os.access(path, mode=os.R_OK):
+        raise Exception('The file:{} is unreadable!'.format(file_name))
+
+
+def read_json_config(json_path: str) -> dict:
+    """Load json_path and return its content; raise Exception if it is not valid JSON or not a JSON object."""
+    file_base_check(json_path)
+    try:
+        with open(json_path, "r", encoding="utf-8") as fr:
+            json_data = json.load(fr)
+    except json.decoder.JSONDecodeError as e:
+        raise Exception('json decode error: config file is not a json format file!') from e
+    if not isinstance(json_data, dict):
+        raise Exception('json decode error: config file is not a json format file!')
+    return json_data
+
+
+def _init():
+    """Configure logging, load ./infer_config.json and make sure the output video path starts out empty."""
+    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+    logger_instance = logging.getLogger()
+    file_base_check("./infer_config.json")
+    infer_config_instance = read_json_config("./infer_config.json")
+    file_base_check(infer_config_instance["video_path"])
+    file_path = infer_config_instance["video_saved_path"]
+    directory = os.path.dirname(file_path)
+    # dirname() is "" for a bare file name and os.makedirs("") raises, so only create a real directory
+    if directory and not os.path.exists(directory):
+        os.makedirs(directory)
+    if os.path.exists(file_path):
+        os.remove(file_path)
+    return logger_instance, infer_config_instance
+
+
+logger, infer_config = _init()
\ No newline at end of file