diff --git a/MPC/kcal_python/CMakeLists.txt b/MPC/kcal_python/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2b44aa42137a1ca6416b776e6f1b979f1cb572c --- /dev/null +++ b/MPC/kcal_python/CMakeLists.txt @@ -0,0 +1,48 @@ +cmake_minimum_required(VERSION 3.15...4.0) +project(kcal_python) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (DEFINED ENV{CONDA_PREFIX}) + set(Python3_ROOT_DIR $ENV{CONDA_PREFIX}) +endif () + +find_package(Python3 COMPONENTS Interpreter Development REQUIRED) +message(STATUS "Python3_EXECUTABLE: ${Python3_EXECUTABLE}") +message(STATUS "Python3_INCLUDE_DIRS: ${Python3_INCLUDE_DIRS}") +message(STATUS "Python3_LIBRARIES: ${Python3_LIBRARIES}") + +# pybind11 +set(PYBIND11_FINDPYTHON ON) +find_package(pybind11 CONFIG REQUIRED) + +# kcal +set(KCAL_LIBS + data_guard_common + data_guard + hitls_bsl + hitls_crypto + mpc_tee + securec +) +set(LIB_SEARCH_DIR "${CMAKE_CURRENT_LIST_DIR}/lib") +set(KCAL_LIB_PATHS "") +foreach (lib_name IN LISTS KCAL_LIBS) + find_library(found_path + NAMES ${lib_name} + PATHS ${LIB_SEARCH_DIR} + NO_DEFAULT_PATH + ) + if (NOT found_path) + message(FATAL_ERROR "Cannot find ${lib_name} in ${LIB_SEARCH_DIR}") + endif () + list(APPEND KCAL_LIB_PATHS ${found_path}) +endforeach () + +# middleware +set(KCAL_MIDDLEWARE_ROOT ${CMAKE_CURRENT_LIST_DIR}/../middleware) +file(GLOB_RECURSE KCAL_MIDDLEWARE_SRCS ${CMAKE_CURRENT_LIST_DIR}/../middleware/kcal/**/*.cc) + +include_directories(${CMAKE_CURRENT_LIST_DIR}/include) + +add_subdirectory(src) diff --git a/MPC/kcal_python/README.md b/MPC/kcal_python/README.md new file mode 100644 index 0000000000000000000000000000000000000000..86569df3cfceb035dd9610108ab7169c148fb7bf --- /dev/null +++ b/MPC/kcal_python/README.md @@ -0,0 +1,98 @@ +# KCAL中间件 Python 接口封装 + +本文介绍 KCAL 中间件用 Python 进行封装(目前仅提供PSI接口)的项目如何 virtCCA 内部部署和验证, 对外提供以 Python 方式集成的思路 + +## 前置条件 + +1. 需获取 `kcal` 包,含 `include`和`lib`目录,获取链接: [https://www.hikunpeng.com/developer/download](https://gitee.com/link?target=https%3A%2F%2Fwww.hikunpeng.com%2Fdeveloper%2Fdownload) +2. 安装 `Python` 和 `pdm` 包, 以及依赖 `pybind11-3.0.1`版本 +3. 运行环境为`virtCCA cvm,前提是用户已经启动两个 virtCCA 的机密虚机(cvm1、cvm2)` + +## 目录结构介绍 + +当前项目目录如下, kcal 包下载后, 需要进行解压并将 kcal 包内的`include`、`lib`目录放到当前目录下, 然后进行构建 + +```bash +. +|-- CMakeLists.txt # pybind11 包装的项目构建文件 +|-- README.md # 说明文档 +|-- build_native.py # 构建 Python 封装包的脚本 +|-- include # kcal 头文件 +|-- kcal # 实际打包进 whl 里面的目录 +| |-- __init__.py +| |-- kcal.pyi # Python 接口存根文件 +|-- lib # kcal 动态链接库 +|-- pyproject.toml # Python 打包管理说明 +|-- src # pybind11 封装 kcal 中间件的源码 +| |-- CMakeLists.txt +| |-- context_ext.cc +| |-- context_ext.h +| |-- kcal_wrapper.cc # Python 对外接口 +|-- test # 测试目录 + |-- __init__.py + |-- demo.py # 演示示例 + |-- socket_util.py # 简单 socket 网络通信实现 +``` + +## 构建 + +进入到`kcal_python`目录, 然后执行以下操作(前提是已经安装好 `Python` 和 `pdm` 包, 这里推荐使用 `uv` 作为 `Python` 的版本管理工具) + +### 安装依赖 + +1. 安装 uv 工具 + + ```bash + curl -LsSf https://astral.sh/uv/install.sh | sh + ``` + +2. 创建虚拟环境 + + ```bash + uv venv --python 3.11 + source .venv/bin/activate + ``` + +3. 安装 pdm 工具 + + ```bash + uv pip install pdm + ``` + +### 打包 + +该打包过程基于`kcal_python`根目录下面的`kcal`目录进行打包, `kcal`加速库会一并打包进`.whl`包内, 并在导入时自动加载 + +```bash +# 先构建 Python 模块 +pdm run build-native + +# 打包 +pdm build +``` + +执行完后会在`dist`目录下面生成打包好的`.whl`包, 然后执行 `uv pip install dist/*.whl --force-reinstall`即可覆盖安装 + +## 部署 + +只需将生成的`.whl`包导入到`cvm`内, 然后在`cvm`内进行安装, 前提是`cvm`内安装有`Python`, 步骤如下 + +注: 机密虚机启动及连接参考: [https://www.hikunpeng.com/document/detail/zh/kunpengcctrustzone/tee/cVMcont/kunpengtee_16_0027.html](https://www.hikunpeng.com/document/detail/zh/kunpengcctrustzone/tee/cVMcont/kunpengtee_16_0027.html) + +```bash +# 这里以 cvm 内的 /home/admin/dev 作为工作目录 +scp dist/*.whl root@:/home/admin/dev/ + +# 安装, 进入到 cvm 内 +pip install *.whl --force-reinstall +``` + +## 测试 + +为方便演示, 这里仅在一台机器上进行测试, 实际情况在两台分离部署的`cvm`内进行测试, 将`test`目录直接拷贝进`cvm`内的`/home/admin/dev`下, 连接`cvm`, 并打开两个终端, 分别运行以下指令, 即可进行`PSI`的测试, 数据量按需修改`test/demo.py`文件 + +```bash +python test/demo.py --server --host "127.0.0.1" -p 9090 +python test/demo.py --client --host "127.0.0.1" -p 9090 +``` + diff --git a/MPC/kcal_python/build_native.py b/MPC/kcal_python/build_native.py new file mode 100644 index 0000000000000000000000000000000000000000..3be95294eadc6a1df38a4809c5e26b3b408bc23c --- /dev/null +++ b/MPC/kcal_python/build_native.py @@ -0,0 +1,83 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +""" +Usage: + pdm run build-native +""" +import subprocess +import sys +from pathlib import Path +import shutil +import os + +DEPENDENCY_LIB_NAMES = [ + "libdata_guard_common.so", + "libdata_guard.so", + "libhitls_bsl.so", + "libhitls_crypto.so", + "libmpc_tee.so", + "libsecurec.so", +] + + +def find_built_extension(build_dir, expected_basename): + if sys.platform.startswith("win"): + exts = [".pyd", ".dll"] + elif sys.platform == "darwin": + exts = [".so", ".dylib"] + else: + exts = [".so"] + for ext in exts: + for p in build_dir.rglob(f"*{expected_basename}*{ext}"): + return p.resolve() + for ext in exts: + candidate = build_dir / (expected_basename + ext) + if candidate.exists(): + return candidate.resolve() + return None + + +def build_native(): + root = Path(__file__).parent.resolve() + build_dir = root / "build" + pkg_dir = root / "kcal" + pkg_dir.mkdir(exist_ok=True) + + module_name = "kcal" + + # Configure + print("CMake configure...") + build_dir.mkdir(parents=True, exist_ok=True) + cmake_cmd = ["cmake", "-S", str(root), "-B", str(build_dir), "-DCMAKE_BUILD_TYPE=Release"] + conda_prefix = os.environ.get("CONDA_PREFIX") + if conda_prefix: + cmake_cmd.append(f"-DPython3_ROOT_DIR={conda_prefix}") + subprocess.run(cmake_cmd, check=True) + + # Build + print("CMake build...") + build_cmd = ["cmake", "--build", str(build_dir), f"-j{os.cpu_count() - 1}"] + if sys.platform.startswith("win"): + build_cmd += ["--config", "Release"] + subprocess.run(build_cmd, check=True) + + # Copy built extension + print("Searching for built extension...") + built_ext = find_built_extension(build_dir, module_name) + if not built_ext: + raise FileNotFoundError(f"Built extension for {module_name} not found in {build_dir}") + target_name = module_name + (".pyd" if sys.platform.startswith("win") else ".so") + target_path = pkg_dir / target_name + shutil.copy2(built_ext, target_path) + print(f"Copied extension: {built_ext} -> {target_path}") + + # Copy dependency libs + lib_dir = root / "lib" + pkg_lib_dir = pkg_dir / "lib" + shutil.copytree(lib_dir, pkg_lib_dir, dirs_exist_ok=True) + + print("build-native finished. Run `pdm build` to create wheel.") + + +if __name__ == "__main__": + build_native() diff --git a/MPC/kcal_python/kcal/__init__.py b/MPC/kcal_python/kcal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02b7a4defb3ba5a2e583b5b86608b989a6b755a3 --- /dev/null +++ b/MPC/kcal_python/kcal/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import os +import time +import ctypes + +_pkg_dir = os.path.dirname(__file__) +_lib_dir = os.path.join(_pkg_dir, "lib") + +if os.path.isdir(_lib_dir): + old_ld = os.environ.get("LD_LIBRARY_PATH", "") + os.environ["LD_LIBRARY_PATH"] = _lib_dir + os.pathsep + old_ld + + failed_files = [] + for so_file in os.listdir(_lib_dir): + if so_file.endswith(".so"): + so_path = os.path.join(_lib_dir, so_file) + try: + ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) + print(f"Preloaded {so_path} (RTLD_GLOBAL)") + except OSError as e: + failed_files.append(so_path) + + for _ in range(15): + if failed_files: + for so_file in failed_files: + time.sleep(0.1) + try: + ctypes.CDLL(str(so_file), mode=ctypes.RTLD_GLOBAL) + print(f"Preloaded {so_file} (RTLD_GLOBAL)") + failed_files.remove(so_file) + except OSError as e: + print(f"Failed to load {so_file} {e}") + continue + + if failed_files: + raise RuntimeError(f"kcal {failed_files} so lib failed to load") + +from .kcal import * diff --git a/MPC/kcal_python/kcal/kcal.pyi b/MPC/kcal_python/kcal/kcal.pyi new file mode 100644 index 0000000000000000000000000000000000000000..f78e2f8500235973d5b3ebc0054c1f39d25b675c --- /dev/null +++ b/MPC/kcal_python/kcal/kcal.pyi @@ -0,0 +1,625 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +""" +KCAL Python bindings. +""" +from __future__ import annotations +import collections.abc +import typing + +__all__: list[str] = ['ADD', 'AVG', 'Add', 'AlgorithmsType', 'Arithmetic', 'Avg', 'BuildDgString', 'Config', 'Context', + 'ContextBase', 'DIV', 'DUMMY', 'Div', 'DummyMode', 'EQUAL', 'Equal', 'FIX_POINT', 'GREATER', + 'GREATER_EQUAL', 'Greater', 'GreaterEqual', 'Input', 'IsOperatorRegistered', 'LESS', 'LESS_EQUAL', + 'Less', 'LessEqual', 'MAKE_SHARE', 'MAX', 'MIN', 'MUL', 'MakeShare', 'Max', 'Min', 'MpcShare', + 'MpcShareSet', 'Mul', 'NON_FIX_POINT', 'NORMAL', 'NO_EQUAL', 'NoEqual', 'OUTPUT_INDEX', + 'OUTPUT_STRING', 'OperatorBase', 'Output', 'PIR', 'PSI', 'Pir', 'Psi', 'REVEAL_SHARE', + 'RegisterAllOps', 'ReleaseMpcShare', 'ReleaseOutput', 'RevealShare', 'SUB', 'SUM', 'ShareType', 'Sub', + 'Sum', 'TeeMode', 'TeeNodeInfo', 'create_operator'] + + +class Add(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class AlgorithmsType: + """ + Members: + + PSI + + PIR + + MAKE_SHARE + + REVEAL_SHARE + + ADD + + SUB + + MUL + + DIV + + LESS + + LESS_EQUAL + + GREATER + + GREATER_EQUAL + + EQUAL + + NO_EQUAL + + SUM + + AVG + + MAX + + MIN + """ + ADD: typing.ClassVar[AlgorithmsType] # value = + AVG: typing.ClassVar[AlgorithmsType] # value = + DIV: typing.ClassVar[AlgorithmsType] # value = + EQUAL: typing.ClassVar[AlgorithmsType] # value = + GREATER: typing.ClassVar[AlgorithmsType] # value = + GREATER_EQUAL: typing.ClassVar[AlgorithmsType] # value = + LESS: typing.ClassVar[AlgorithmsType] # value = + LESS_EQUAL: typing.ClassVar[AlgorithmsType] # value = + MAKE_SHARE: typing.ClassVar[AlgorithmsType] # value = + MAX: typing.ClassVar[AlgorithmsType] # value = + MIN: typing.ClassVar[AlgorithmsType] # value = + MUL: typing.ClassVar[AlgorithmsType] # value = + NO_EQUAL: typing.ClassVar[AlgorithmsType] # value = + PIR: typing.ClassVar[AlgorithmsType] # value = + PSI: typing.ClassVar[AlgorithmsType] # value = + REVEAL_SHARE: typing.ClassVar[AlgorithmsType] # value = + SUB: typing.ClassVar[AlgorithmsType] # value = + SUM: typing.ClassVar[AlgorithmsType] # value = + __members__: typing.ClassVar[dict[ + str, AlgorithmsType]] # value = {'PSI': , 'PIR': , 'MAKE_SHARE': , 'REVEAL_SHARE': , 'ADD': , 'SUB': , 'MUL': , 'DIV': , 'LESS': , 'LESS_EQUAL': , 'GREATER': , 'GREATER_EQUAL': , 'EQUAL': , 'NO_EQUAL': , 'SUM': , 'AVG': , 'MAX': , 'MIN': } + + def __eq__(self, other: typing.Any) -> bool: + ... + + def __getstate__(self) -> int: + ... + + def __hash__(self) -> int: + ... + + def __index__(self) -> int: + ... + + def __init__(self, value: typing.SupportsInt) -> None: + ... + + def __int__(self) -> int: + ... + + def __ne__(self, other: typing.Any) -> bool: + ... + + def __repr__(self) -> str: + ... + + def __setstate__(self, state: typing.SupportsInt) -> None: + ... + + def __str__(self) -> str: + ... + + @property + def name(self) -> str: + ... + + @property + def value(self) -> int: + ... + + +class Arithmetic(OperatorBase): + pass + + +class Avg(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class Config: + useSMAlg: bool + + def __init__(self) -> None: + ... + + @property + def fixBits(self) -> int: + ... + + @fixBits.setter + def fixBits(self, arg0: typing.SupportsInt) -> None: + ... + + @property + def nodeId(self) -> int: + ... + + @nodeId.setter + def nodeId(self, arg0: typing.SupportsInt) -> None: + ... + + @property + def threadCount(self) -> int: + ... + + @threadCount.setter + def threadCount(self, arg0: typing.SupportsInt) -> None: + ... + + @property + def worldSize(self) -> int: + ... + + @worldSize.setter + def worldSize(self, arg0: typing.SupportsInt) -> None: + ... + + +class Context: + @staticmethod + def Create(arg0: Config, arg1: collections.abc.Callable, arg2: collections.abc.Callable) -> Context: + ... + + def __init__(self) -> None: + ... + + +class ContextBase: + def GetConfig(self) -> Config: + ... + + def GetWorldSize(self) -> int: + ... + + def IsValid(self) -> bool: + ... + + def NodeId(self) -> int: + ... + + def __init__(self) -> None: + ... + + +class Div(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class DummyMode: + """ + Members: + + NORMAL + + DUMMY + """ + DUMMY: typing.ClassVar[DummyMode] # value = + NORMAL: typing.ClassVar[DummyMode] # value = + __members__: typing.ClassVar[ + dict[str, DummyMode]] # value = {'NORMAL': , 'DUMMY': } + + def __eq__(self, other: typing.Any) -> bool: + ... + + def __getstate__(self) -> int: + ... + + def __hash__(self) -> int: + ... + + def __index__(self) -> int: + ... + + def __init__(self, value: typing.SupportsInt) -> None: + ... + + def __int__(self) -> int: + ... + + def __ne__(self, other: typing.Any) -> bool: + ... + + def __repr__(self) -> str: + ... + + def __setstate__(self, state: typing.SupportsInt) -> None: + ... + + def __str__(self) -> str: + ... + + @property + def name(self) -> str: + ... + + @property + def value(self) -> int: + ... + + +class Equal(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class Greater(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class GreaterEqual(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class Input: + @staticmethod + def Create() -> Input: + ... + + def Fill(self, arg0: collections.abc.Sequence[str]) -> None: + ... + + def Get(self) -> DG_TeeInput: + ... + + def Set(self, arg0: DG_TeeInput) -> None: + ... + + def Size(self) -> int: + ... + + @typing.overload + def __init__(self) -> None: + ... + + @typing.overload + def __init__(self, arg0: DG_TeeInput) -> None: + ... + + +class Less(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class LessEqual(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class MakeShare(Arithmetic): + def Run(self, arg0: Input, arg1: typing.SupportsInt, arg2: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class Max(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class Min(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class MpcShare: + @staticmethod + def Create() -> MpcShare: + ... + + def Get(self) -> DG_MpcShare: + ... + + def Set(self, arg0: DG_MpcShare) -> None: + ... + + def Size(self) -> int: + ... + + def Type(self) -> ShareType: + ... + + @typing.overload + def __init__(self) -> None: + ... + + @typing.overload + def __init__(self, arg0: DG_MpcShare) -> None: + ... + + +class MpcShareSet: + @staticmethod + def Create(arg0: collections.abc.Sequence[MpcShare]) -> MpcShareSet: + ... + + def Get(self) -> DG_MpcShareSet: + ... + + def __init__(self) -> None: + ... + + +class Mul(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class NoEqual(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class OperatorBase: + def GetType(self) -> AlgorithmsType: + ... + + +class Pir(OperatorBase): + def ClientQuery(self, arg0: Input, arg1: Input, arg2: DummyMode) -> int: + ... + + def ServerAnswer(self) -> int: + ... + + def ServerPreProcess(self, arg0: DG_PairList) -> int: + ... + + def __init__(self) -> None: + ... + + +class Psi(OperatorBase): + def Run(self, arg0: Input, arg1: Input, arg2: TeeMode) -> int: + ... + + def __init__(self) -> None: + ... + + +class RevealShare(Arithmetic): + def Run(self, arg0: MpcShare, arg1: Input) -> int: + ... + + def __init__(self) -> None: + ... + + +class ShareType: + """ + Members: + + FIX_POINT + + NON_FIX_POINT + """ + FIX_POINT: typing.ClassVar[ShareType] # value = + NON_FIX_POINT: typing.ClassVar[ShareType] # value = + __members__: typing.ClassVar[dict[ + str, ShareType]] # value = {'FIX_POINT': , 'NON_FIX_POINT': } + + def __eq__(self, other: typing.Any) -> bool: + ... + + def __getstate__(self) -> int: + ... + + def __hash__(self) -> int: + ... + + def __index__(self) -> int: + ... + + def __init__(self, value: typing.SupportsInt) -> None: + ... + + def __int__(self) -> int: + ... + + def __ne__(self, other: typing.Any) -> bool: + ... + + def __repr__(self) -> str: + ... + + def __setstate__(self, state: typing.SupportsInt) -> None: + ... + + def __str__(self) -> str: + ... + + @property + def name(self) -> str: + ... + + @property + def value(self) -> int: + ... + + +class Sub(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class Sum(Arithmetic): + def Run(self, arg0: MpcShareSet, arg1: MpcShare) -> int: + ... + + def __init__(self) -> None: + ... + + +class TeeMode: + """ + Members: + + OUTPUT_INDEX + + OUTPUT_STRING + """ + OUTPUT_INDEX: typing.ClassVar[TeeMode] # value = + OUTPUT_STRING: typing.ClassVar[TeeMode] # value = + __members__: typing.ClassVar[dict[ + str, TeeMode]] # value = {'OUTPUT_INDEX': , 'OUTPUT_STRING': } + + def __eq__(self, other: typing.Any) -> bool: + ... + + def __getstate__(self) -> int: + ... + + def __hash__(self) -> int: + ... + + def __index__(self) -> int: + ... + + def __init__(self, value: typing.SupportsInt) -> None: + ... + + def __int__(self) -> int: + ... + + def __ne__(self, other: typing.Any) -> bool: + ... + + def __repr__(self) -> str: + ... + + def __setstate__(self, state: typing.SupportsInt) -> None: + ... + + def __str__(self) -> str: + ... + + @property + def name(self) -> str: + ... + + @property + def value(self) -> int: + ... + + +class TeeNodeInfo: + def __init__(self) -> None: + ... + + @property + def nodeId(self) -> int: + ... + + @nodeId.setter + def nodeId(self, arg0: typing.SupportsInt) -> None: + ... + + +def BuildDgString(arg0: collections.abc.Sequence[str]) -> typing.Any: + ... + + +def IsOperatorRegistered(arg0: AlgorithmsType) -> bool: + ... + + +def RegisterAllOps() -> None: + ... + + +def ReleaseMpcShare(arg0: DG_MpcShare) -> None: + ... + + +def ReleaseOutput(arg0: DG_TeeInput) -> None: + ... + + +def create_operator(arg0: ContextBase, arg1: AlgorithmsType) -> OperatorBase: + ... + + +ADD: AlgorithmsType # value = +AVG: AlgorithmsType # value = +DIV: AlgorithmsType # value = +DUMMY: DummyMode # value = +EQUAL: AlgorithmsType # value = +FIX_POINT: ShareType # value = +GREATER: AlgorithmsType # value = +GREATER_EQUAL: AlgorithmsType # value = +LESS: AlgorithmsType # value = +LESS_EQUAL: AlgorithmsType # value = +MAKE_SHARE: AlgorithmsType # value = +MAX: AlgorithmsType # value = +MIN: AlgorithmsType # value = +MUL: AlgorithmsType # value = +NON_FIX_POINT: ShareType # value = +NORMAL: DummyMode # value = +NO_EQUAL: AlgorithmsType # value = +OUTPUT_INDEX: TeeMode # value = +OUTPUT_STRING: TeeMode # value = +PIR: AlgorithmsType # value = +PSI: AlgorithmsType # value = +REVEAL_SHARE: AlgorithmsType # value = +SUB: AlgorithmsType # value = +SUM: AlgorithmsType # value = +Output = Input diff --git a/MPC/kcal_python/pyproject.toml b/MPC/kcal_python/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..13e423cc68ab7167b656dee652f3f35525bdafba --- /dev/null +++ b/MPC/kcal_python/pyproject.toml @@ -0,0 +1,20 @@ +[project] +name = "kcal" +version = "0.1.0" +description = "Python bindings for kcal library (example, CMake + pybind11 + PDM)" +readme = "README.md" +requires-python = ">=3.8" +dependencies = ["pdm"] + +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[tool.pdm.build] +# ensure package dir and README are included in wheel +includes = ["kcal/", "README.md"] + +[tool.pdm.scripts] +# run `pdm run build-native` to configure/build/copy/generate-stubs +build-native = { call = "build_native:build_native" } +clean = { shell = "rm -rf build/ __pycache__/ *.egg-info/ dist/ CMakeFiles/ _skbuild/" } diff --git a/MPC/kcal_python/src/CMakeLists.txt b/MPC/kcal_python/src/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2827aa9c83fd458be309aa2456b217b204180d2 --- /dev/null +++ b/MPC/kcal_python/src/CMakeLists.txt @@ -0,0 +1,6 @@ +pybind11_add_module(kcal kcal_wrapper.cc context_ext.cc ${KCAL_MIDDLEWARE_SRCS}) +target_include_directories(kcal PRIVATE + $ + $ +) +target_link_libraries(kcal PRIVATE ${KCAL_LIB_PATHS} pthread) \ No newline at end of file diff --git a/MPC/kcal_python/src/context_ext.cc b/MPC/kcal_python/src/context_ext.cc new file mode 100644 index 0000000000000000000000000000000000000000..02de4405efbc2668e643d6fd1c86f69e7886f70c --- /dev/null +++ b/MPC/kcal_python/src/context_ext.cc @@ -0,0 +1,67 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +#include "context_ext.h" + +namespace kcal { + +ContextExt *ContextExt::currentContext_ = nullptr; + +ContextExt::ContextExt(SendCallback sendCb, RecvCallback recvCb) + : sendCallback_(std::move(sendCb)), + recvCallback_(std::move(recvCb)) +{ +} + +ContextExt::~ContextExt() +{ + if (currentContext_ == this) { + currentContext_ = nullptr; + } +} + +std::shared_ptr ContextExt::Create(KCAL_Config config, SendCallback sendCb, RecvCallback recvCb) +{ + auto ctx = std::shared_ptr(new ContextExt(std::move(sendCb), std::move(recvCb))); + + currentContext_ = ctx.get(); + + TEE_NET_RES net_res{}; + net_res.funcSendData = &ContextExt::SendDataThunk; + net_res.funcRecvData = &ContextExt::RecvDataThunk; + + ctx->kcalCtx_ = Context::Create(config, &net_res); + if (!ctx->kcalCtx_) { + currentContext_ = nullptr; + return nullptr; + } + + return ctx; +} + +int ContextExt::SendDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 len) +{ + if (!currentContext_ || !currentContext_->sendCallback_) { + return -1; + } + + try { + return currentContext_->sendCallback_(*nodeInfo, buf, len); + } catch (const std::exception &e) { + return -1; + } +} + +int ContextExt::RecvDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 *len) +{ + if (!currentContext_ || !currentContext_->recvCallback_) { + return -1; + } + + try { + return currentContext_->recvCallback_(*nodeInfo, buf, *len); + } catch (const std::exception &e) { + return -1; + } +} + +} // namespace kcal \ No newline at end of file diff --git a/MPC/kcal_python/src/context_ext.h b/MPC/kcal_python/src/context_ext.h new file mode 100644 index 0000000000000000000000000000000000000000..9ad3ff418e358fd769594afe3fac250f622f0993 --- /dev/null +++ b/MPC/kcal_python/src/context_ext.h @@ -0,0 +1,40 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "kcal/core/context.h" + +namespace kcal { + +class ContextExt { +public: + using SendCallback = std::function; + using RecvCallback = std::function; + + static std::shared_ptr Create(KCAL_Config config, SendCallback sendCb, RecvCallback recvCb); + + ContextExt() = default; + ~ContextExt(); + + std::shared_ptr GetKcalContext() const { return kcalCtx_; } + + static ContextExt *GetCurrentContext() { return currentContext_; } + +private: + ContextExt(SendCallback sendCb, RecvCallback recvCb); + + static ContextExt *currentContext_; + + SendCallback sendCallback_; + RecvCallback recvCallback_; + std::shared_ptr kcalCtx_; + + static int SendDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 len); + static int RecvDataThunk(TeeNodeInfo *nodeInfo, unsigned char *buf, u64 *len); +}; + +} // namespace kcal \ No newline at end of file diff --git a/MPC/kcal_python/src/kcal_wrapper.cc b/MPC/kcal_python/src/kcal_wrapper.cc new file mode 100644 index 0000000000000000000000000000000000000000..a03343c459c58e9b942b491522f38b53d9ea187d --- /dev/null +++ b/MPC/kcal_python/src/kcal_wrapper.cc @@ -0,0 +1,251 @@ +// Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +#include +#include +#include + +#include "context_ext.h" +#include "kcal/core/operator_base.h" +#include "kcal/core/operator_manager.h" +#include "kcal/operator/all_operator_register.h" +#include "kcal/operator/kcal_psi.h" +#include "kcal/utils/io.h" + +namespace py = pybind11; + +namespace kcal { + +namespace { + +void FeedKcalInput(const py::list &pyList, io::KcalInput *kcalInput) +{ + auto *dgString = new (std::nothrow) DG_String[pyList.size()]; + if (!dgString) { + throw std::bad_alloc(); + } + for (size_t i = 0; i < pyList.size(); ++i) { + if (!PyUnicode_Check(pyList[i].ptr())) { + throw std::runtime_error("need str"); + } + + Py_ssize_t sz; + const char *utf8 = PyUnicode_AsUTF8AndSize(pyList[i].ptr(), &sz); + if (!utf8) { + throw std::bad_alloc(); + } + + dgString[i].str = strdup(utf8); + dgString[i].size = static_cast(sz) + 1; + } + DG_TeeInput **internalInput = kcalInput->GetSecondaryPointer(); + (*internalInput)->data.strings = dgString; + (*internalInput)->size = pyList.size(); + (*internalInput)->dataType = MPC_STRING; +} + +void FeedPsiOutput(io::KcalOutput &kcalOutput, py::list &pyList, DG_TeeMode mode) +{ + auto *outPtr = kcalOutput.Get(); + for (size_t i = 0; i < outPtr->size; ++i) { + if (mode == TEE_OUTPUT_INDEX) { + pyList.append(outPtr->data.u64Numbers[i]); + } else if (mode == TEE_OUTPUT_STRING) { + pyList.append(outPtr->data.strings[i].str); + } + } +} + +} // namespace + +class PyCallbackAdapter { +public: + static int PySendCallback(const TeeNodeInfo &nodeInfo, const uint8_t *data, size_t dataLen, + const py::function &pySendFunc) + { + if (!data) { + return 0; + } + try { + py::dict nodeInfoDict; + nodeInfoDict["nodeId"] = nodeInfo.nodeId; + + // zero-copy + py::memoryview dataMview = py::memoryview::from_buffer( + const_cast(data), {static_cast(dataLen)}, {sizeof(uint8_t)}); + + py::object result = pySendFunc(nodeInfoDict, dataMview); + return result.cast(); + } catch (const py::error_already_set &e) { + py::print("Python send callback error:", e.what()); + return -1; + } catch (const std::exception &e) { + py::print("Send callback error:", e.what()); + return -1; + } + } + + static int PyRecvCallback(const TeeNodeInfo &nodeInfo, uint8_t *buffer, size_t maxLen, + const py::function &pyRecvFunc) + { + if (!buffer) { + return 0; + } + try { + py::dict nodeInfoDict; + nodeInfoDict["nodeId"] = nodeInfo.nodeId; + + // zero-copy + py::memoryview bufferMview = + py::memoryview::from_buffer(buffer, {static_cast(maxLen)}, {sizeof(uint8_t)}, false); + + py::object result = pyRecvFunc(nodeInfoDict, bufferMview); + if (result.is_none()) { + return -1; + } + return result.cast(); + } catch (const py::error_already_set &e) { + py::print("Python recv callback error:", e.what()); + return -1; + } catch (const std::exception &e) { + py::print("Recv callback error:", e.what()); + return -1; + } + } +}; + +void BindIoClasses(py::module_ &m) +{ + py::class_(m, "MpcShare") + .def(py::init<>()) + .def(py::init()) + .def_static("Create", &io::KcalMpcShare::Create, py::return_value_policy::take_ownership) + .def("Set", &io::KcalMpcShare::Set) + .def("Get", [](io::KcalMpcShare &self) -> DG_MpcShare* { return self.Get(); }, + py::return_value_policy::reference) + .def("Size", &io::KcalMpcShare::Size) + .def("Type", &io::KcalMpcShare::Type); + + py::class_(m, "MpcShareSet") + .def(py::init<>()) + .def_static("Create", + [](const std::vector &shares) { + return io::KcalMpcShareSet::Create(shares); + }, + py::return_value_policy::take_ownership) + .def("Get", [](io::KcalMpcShareSet &self) -> DG_MpcShareSet* { return self.Get(); }, + py::return_value_policy::reference); + + py::class_(m, "Input") + .def(py::init<>()) + .def(py::init()) + .def_static("Create", &io::KcalInput::Create, py::return_value_policy::take_ownership) + .def("Set", &io::KcalInput::Set) + .def("Get", &io::KcalInput::Get, py::return_value_policy::reference) + .def("Fill", &io::KcalInput::Fill) + .def("Size", &io::KcalInput::Size); + + // Alias of Input + m.attr("Output") = m.attr("Input"); +} + +void BindOtherOperators(py::module_ &m) +{ + // PSI + py::class_>(m, "Psi") + .def(py::init<>()) + .def("run", [](Psi &self, const py::list &input, py::list &output, DG_TeeMode mode) -> int { + std::unique_ptr kcalInput(io::KcalInput::Create()); + FeedKcalInput(input, kcalInput.get()); + io::KcalOutput kcalOutput; + int ret = self.Run(kcalInput->Get(), kcalOutput.GetSecondaryPointer(), mode); + FeedPsiOutput(kcalOutput, output, mode); + return ret; + }); +} + +PYBIND11_MODULE(kcal, m) +{ + m.doc() = "KCAL Python bindings."; + + py::enum_(m, "AlgorithmsType") + .value("PSI", KCAL_AlgorithmsType::PSI) + .export_values(); + + py::enum_(m, "TeeMode") + .value("OUTPUT_INDEX", TEE_OUTPUT_INDEX) + .value("OUTPUT_STRING", TEE_OUTPUT_STRING) + .export_values(); + + py::enum_(m, "DummyMode").value("NORMAL", NORMAL).value("DUMMY", DUMMY).export_values(); + + py::enum_(m, "ShareType") + .value("FIX_POINT", FIX_POINT) + .value("NON_FIX_POINT", NON_FIX_POINT) + .export_values(); + + py::class_(m, "TeeNodeInfo").def(py::init<>()).def_readwrite("nodeId", &TeeNodeInfo::nodeId); + + py::class_(m, "Config") + .def(py::init<>()) + .def_readwrite("nodeId", &KCAL_Config::nodeId) + .def_readwrite("fixBits", &KCAL_Config::fixBits) + .def_readwrite("threadCount", &KCAL_Config::threadCount) + .def_readwrite("worldSize", &KCAL_Config::worldSize) + .def_readwrite("useSMAlg", &KCAL_Config::useSMAlg); + + py::class_> contextClass(m, "ContextBase"); + contextClass.def(py::init<>()) + .def("GetWorldSize", &Context::GetWorldSize) + .def("NodeId", &Context::NodeId) + .def("IsValid", &Context::IsValid) + .def("GetConfig", &Context::GetConfig); + + py::class_>(m, "Context") + .def(py::init<>()) + .def_static("create", [](KCAL_Config config, py::function sendCb, py::function recvCb) { + auto cppSendCb = [sendCb](const TeeNodeInfo &nodeInfo, const uint8_t *data, size_t dataLen) { + return PyCallbackAdapter::PySendCallback(nodeInfo, data, dataLen, sendCb); + }; + + auto cppRecvCb = [recvCb](const TeeNodeInfo &nodeInfo, uint8_t *buffer, size_t maxLen) { + return PyCallbackAdapter::PyRecvCallback(nodeInfo, buffer, maxLen, recvCb); + }; + + return ContextExt::Create(config, cppSendCb, cppRecvCb); + }); + + BindIoClasses(m); + + py::class_>(m, "OperatorBase").def("GetType", &OperatorBase::GetType); + + BindOtherOperators(m); + + m.def("create_operator", + [](const std::shared_ptr &context, KCAL_AlgorithmsType type) -> std::shared_ptr { + switch (type) { + case KCAL_AlgorithmsType::PSI: + return OperatorManager::CreateOperator(context->GetKcalContext()); + default: + throw std::runtime_error("Unsupported operator type"); + } + }); + + m.def("is_op_registered", &OperatorManager::IsOperatorRegistered); + + m.def("register_all_ops", &RegisterAllOps); + + m.def("build_dg_string", [](const std::vector &strings) -> py::object { + DG_String *dg = nullptr; + int ret = io::DataHelper::BuildDgString(strings, &dg); + if (ret != 0) { + throw std::runtime_error("BuildDgString failed"); + } + return py::cast(dg); + }); + + m.def("release_output", [](DG_TeeOutput *output) { io::DataHelper::ReleaseOutput(&output); }); + + m.def("release_mpc_share", [](DG_MpcShare *share) { io::DataHelper::ReleaseMpcShare(&share); }); +} + +} // namespace kcal diff --git a/MPC/kcal_python/test/__init__.py b/MPC/kcal_python/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f65c534b718544536798b2ca9159a6f9d509f1c5 --- /dev/null +++ b/MPC/kcal_python/test/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. diff --git a/MPC/kcal_python/test/demo.py b/MPC/kcal_python/test/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..f6471dbb581bfba7749e55cd84d160d27c286cce --- /dev/null +++ b/MPC/kcal_python/test/demo.py @@ -0,0 +1,94 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +from __future__ import annotations + +import socket +import sys + +import kcal +import argparse + +import socket_util + +_client_socket = None +_server_socket = None + +kcal.register_all_ops() + +""" +server: + nodeId: 0 + socket: _client_socket +client: + nodeId: 1 + socket: _server_socket +""" + + +def get_fd(node_info: dict) -> socket.socket: + return _server_socket if node_info['nodeId'] == 0 else _client_socket + + +def on_send_data(node_info: dict, data_buffer: memoryview) -> int: + s = get_fd(node_info) + return socket_util.send_data(s, data_buffer) + + +def on_recv_data(node_info: dict, buffer: memoryview) -> int: + s = get_fd(node_info) + return socket_util.recv_data(s, buffer) + + +def psi_demo(is_server: bool): + config = kcal.Config() + config.nodeId = 0 if is_server else 1 + config.worldSize = 2 + config.fixBits = 3 + config.threadCount = 32 + config.useSMAlg = False + + context = kcal.Context.create(config, on_send_data, on_recv_data) + + op = kcal.create_operator(context, kcal.AlgorithmsType.PSI) + + input0 = ["4", "3", "2", "1"] + input1 = ["1", "3", "4", "5"] + output = [] + import time + start_time = time.time() + if is_server: + op.run(input0, output, kcal.TeeMode.OUTPUT_INDEX) + else: + op.run(input1, output, kcal.TeeMode.OUTPUT_INDEX) + print(len(output)) + end_time = time.time() + duration_ms = (end_time - start_time) * 1000 # ms + print(f"run cost: {duration_ms:.2f} ms") + + +def main(argv=None): + parser = argparse.ArgumentParser(description="KCAL python wrapper demo.") + try: + group = parser.add_mutually_exclusive_group(required=True) + group.add_argument("--server", action="store_true", default=False, help="start server") + group.add_argument("--client", action="store_true", default=False, help="start client") + parser.add_argument("--host", type=str, default="127.0.0.1") + parser.add_argument("-p", "--port", type=int, required=True) + args = parser.parse_args(argv) + except argparse.ArgumentParser: + parser.print_help() + sys.exit(1) + + global _client_socket, _server_socket + if args.server: + _client_socket = socket_util.init_server(args.host, args.port) + psi_demo(True) + _client_socket.close() + elif args.client: + _server_socket = socket_util.init_client(args.host, args.port) + psi_demo(False) + _server_socket.close() + + +if __name__ == "__main__": + main() diff --git a/MPC/kcal_python/test/socket_util.py b/MPC/kcal_python/test/socket_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f77c74bf441157053d91d709f4768671faa1b24e --- /dev/null +++ b/MPC/kcal_python/test/socket_util.py @@ -0,0 +1,60 @@ +# Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved. + +import socket +import struct + + +def init_server(host: str, port: int) -> socket.socket: + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + server_socket.bind((host, port)) + server_socket.listen(1) + print(f"Server listening on {host}:{port}") + c, addr = server_socket.accept() + print(f"{addr} connected") + return c + + +def init_client(host: str, port: int) -> socket.socket: + client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + client_socket.connect((host, port)) + print(f"Connected to server {host}:{port}") + return client_socket + + +def send_data(sock: socket.socket, data: memoryview) -> int: + total_sent = 0 + data_len = len(data) + + sock.sendall(struct.pack('!I', data_len)) + + while total_sent < data_len: + sent = sock.send(data[total_sent:]) + if sent == 0: + raise RuntimeError("Socket connection broken") + total_sent += sent + + return 0 + + +def recv_data(sock: socket.socket, buffer: memoryview) -> int: + len_data = sock.recv(4) + if not len_data: + return 0 + + data_len = struct.unpack('!I', len_data)[0] + + if data_len > len(buffer): + raise ValueError(f"Buffer too small: {len(buffer)} < {data_len}") + + total_received = 0 + while total_received < data_len: + remaining = data_len - total_received + chunk = sock.recv(min(4096, remaining)) + if not chunk: + raise RuntimeError("Socket connection broken") + + buffer[total_received:total_received + len(chunk)] = chunk + total_received += len(chunk) + + return 0 diff --git a/MPC/middleware/CMakeLists.txt b/MPC/middleware/CMakeLists.txt index 4dc998808e69486dc7e4a5db14cd072a92f66a7b..c3b106699bb692662f1d35a510cff8ff6debc66a 100644 --- a/MPC/middleware/CMakeLists.txt +++ b/MPC/middleware/CMakeLists.txt @@ -4,13 +4,16 @@ project(kcal_middleware LANGUAGES C CXX) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) -find_package(libkcal) +find_package(libkcal REQUIRED) -include_directories(${CMAKE_CURRENT_SOURCE_DIR}) - -add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/kcal) +add_subdirectory(kcal) install(DIRECTORY kcal DESTINATION include FILES_MATCHING PATTERN "*.h" ) + +install(EXPORT kcal_middlewareTargets + FILE kcal_middleConfig.cmake + DESTINATION lib/cmake/kcal_middle +) \ No newline at end of file diff --git a/MPC/middleware/kcal/CMakeLists.txt b/MPC/middleware/kcal/CMakeLists.txt index a66297233941522b6e1d0f3d03506959b4890c89..ce20b9605cd132a23c75ee139a48e3213506f6cc 100644 --- a/MPC/middleware/kcal/CMakeLists.txt +++ b/MPC/middleware/kcal/CMakeLists.txt @@ -1,9 +1,15 @@ file(GLOB_RECURSE COMMON_SRCS ${CMAKE_CURRENT_LIST_DIR}/*.cc) add_library(kcal_middle ${COMMON_SRCS}) +target_include_directories(kcal_middle PUBLIC + $ + $ +) target_link_libraries(kcal_middle PUBLIC lib_kcal) install(TARGETS kcal_middle + EXPORT kcal_middlewareTargets LIBRARY DESTINATION lib ARCHIVE DESTINATION lib + INCLUDES DESTINATION include )