diff --git a/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-609.ckpt b/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-609.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..7051b91e917b826506b87711c7a327a0842d8e18 Binary files /dev/null and b/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-609.ckpt differ diff --git a/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-804.ckpt b/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-804.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..e3abfd1fb60da0ad3a6a58fdedd3abb734bad4a0 Binary files /dev/null and b/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-804.ckpt differ diff --git a/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-834.ckpt b/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-834.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..f6e0aca53f8a07bc6510f1b81e95933b0eb343c8 Binary files /dev/null and b/examples/community/homomorphic_inference/BestCheckpoint/resnet20-best-834.ckpt differ diff --git a/examples/community/homomorphic_inference/BestCheckpoint/resnet50-best.ckpt b/examples/community/homomorphic_inference/BestCheckpoint/resnet50-best.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..e0e7ab4b3ade3d463dc40094ec77389fe77cdb88 Binary files /dev/null and b/examples/community/homomorphic_inference/BestCheckpoint/resnet50-best.ckpt differ diff --git a/examples/community/homomorphic_inference/README.md b/examples/community/homomorphic_inference/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e8cb13be7f28d2b0f714110fdf3cf13ef94a9942 --- /dev/null +++ b/examples/community/homomorphic_inference/README.md @@ -0,0 +1,29 @@ +# 基于全同态加密的隐私推理 
+本项目基于mindspore框架进行多项式残差网络训练,所训练多项式模型用于CKKS密文推理。基于CKKS的密文推理框架来自README中的原型论文,开源库地址:https://github.com/snu-ccl/FHE-MP-CNN + +## 原型论文 + +Eunsang Lee, Joon-Woo Lee, Junghyun Lee, Young-Sik Kim, Yongjune Kim, JongSeon No, and Woosuk Choi. Low-complexity deep convolutional neural networks on fully homomorphic encryption using multiplexed parallel convolutions. In International Conference on Machine Learning, pages 12403–12422. PMLR, 2022. +## 环境要求 + +Mindspore >= 1.9 + + +## 脚本说明 + +```markdown +├── README.md +├── BestCheckpoint //训练后模型 +├── data //数据集 +├── datasets-cifar10-bin //数据集 +├── mindspore_poly +│ ├── coeffResult //预训练多项式系数 +│ ├── degreeResult //预训练多项式阶数 +│ ├── model +│ │ ├── resnet_cifar10.py //多项式网络定义 +│ │ └── utils_approx.py //多项式计算函数 +│ └── train_resnet20_cifar10.py //模型训练 +├── mindspore_resnet20.ipynb //mindspore框架下标准网络训练 +├── mindspore_resnet20.py //mindspore框架下标准网络训练 +└── pytorch_resnet20.py //pytorch框架下标准网络训练 +``` diff --git a/examples/community/homomorphic_inference/data/keep b/examples/community/homomorphic_inference/data/keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/community/homomorphic_inference/datasets-cifar10-bin/cifar-10-batches-bin/keep b/examples/community/homomorphic_inference/datasets-cifar10-bin/cifar-10-batches-bin/keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_10.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_10.txt new file mode 100644 index 0000000000000000000000000000000000000000..616865dd568eef836d880ce8fc765c3c91393aad --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_10.txt @@ -0,0 +1,30 @@ +-0.16804881224859701754e-46 +10.854184257744249796 +0.51921340560426110074e-45 +-62.28339252110988471 
+-0.16735871500743852985e-44 +114.36922782044335657 +0.11543707669236311121e-44 +-62.80234969730743326 +0.78625356248397092829e-38 +4.1397617098511193064 +-0.71824174164994055658e-37 +-5.8499764021167980684 +0.51787863444278289581e-37 +2.9437625565928026676 +-0.93305974396004940898e-38 +-0.45453043746015205795 +0.37537415358329215971e-38 +3.2995673904373325138 +-0.10453714002088916108e-36 +-7.8422726029135582971 +0.41864789598423106147e-36 +12.890776411556469438 +-0.60951015954085533252e-36 +-12.491711258448624139 +0.40547544124712444853e-36 +6.9416799142807483492 +-0.12677008781584870523e-36 +-2.0429806739994296629 +0.15245219740063652538e-37 +0.24640713892603125782 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_11.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_11.txt new file mode 100644 index 0000000000000000000000000000000000000000..2c241242595ef908aa331f7af0143e78dbfacac7 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_11.txt @@ -0,0 +1,44 @@ +-0.75773973997740627597e-30 +11.259066740295440722 +0.35029873643910904994e-28 +-65.469293332997393075 +-0.11093952997692154569e-27 +120.69463427775764571 +0.75910235759400087266e-28 +-66.40196953778250699 +0.6707459348521799376e-48 +4.704776242108835538 +-0.35608959461554233042e-47 +-6.7988485159668168237 +0.14359564606992449774e-47 +3.3152510438287307671 +-0.14985342138579284032e-48 +-0.48936293685989724605 +-0.37246666638164325883e-45 +5.3633425765495343159 +0.43873256885377774404e-43 +-35.51695554419623341 +-0.72755862709513560355e-42 +177.80730411564418744 +0.46756335314744358534e-41 +-592.29739541502495217 +-0.1559191149483180791e-40 +1348.9169188936272623 +0.3101987517898984534e-40 +-2158.7644508493821289 +-0.39795227464460684802e-40 +2473.6568558691888017 +0.34425836717395760687e-40 +-2049.1354253624894092 +-0.20544736179116039936e-40 +1227.3931709055906409 +0.84969444743862631574e-41 
+-525.82617513400260298 +-0.23985636500344662033e-41 +156.93055871284079877 +0.44222346859769547418e-42 +-30.965859564591263308 +-0.48133142981936186693e-43 +3.6289400081496844467 +0.23518137057881136703e-44 +-0.1911602837499393819 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_12.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_12.txt new file mode 100644 index 0000000000000000000000000000000000000000..0227ed2a13b0c0ade44d1c7e086c61affae31544 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_12.txt @@ -0,0 +1,52 @@ +0.32867125379815857291e-44 +11.552304235722389133 +-0.29234255228681799398e-42 +-67.779451344096825362 +0.933659553243619424e-42 +125.28374040456207308 +-0.64131951207618758946e-42 +-69.014290823293429674 +0.64108738894863390065e-45 +9.6516763618162694555 +-0.12282232950603783088e-42 +-61.693917453846963018 +0.62062456634083572074e-42 +155.17035165229809557 +-0.996218491919333066e-42 +-182.69758238321474774 +0.72774896827061079047e-42 +112.91072652540604621 +-0.26913492484561434183e-42 +-37.775241177026371927 +0.49323508874283516076e-43 +6.475039097323445253 +-0.35687682645890630524e-44 +-0.44561336572336147215 +0.47771057631279131586e-46 +5.2588835557174597154 +-0.29457192143837593098e-44 +-33.723359379428401924 +0.44327940113287919025e-43 +164.98308501345742927 +-0.27686398551955233933e-42 +-541.40889140699277564 +0.91518100299426399978e-42 +1222.9620799796357356 +-0.18207112804794013148e-41 +-1952.0191056647939988 +0.23485275845578151281e-41 +2240.8402137830087186 +-0.20516930050320509213e-41 +-1866.3491698317060853 +0.12415752822080049282e-41 +1127.2211784312163129 +-0.52285077977730429736e-42 +-488.07047463838013451 +0.15092324971381413086e-42 +147.49784630892093329 +-0.28576304536464352374e-43 +-29.517104887952670173 +0.32076854965476087405e-44 +3.5126952093099453428 +-0.16226398549339504967e-45 +-0.1881018365578797706 diff --git 
a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_13.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_13.txt new file mode 100644 index 0000000000000000000000000000000000000000..43985eb99d0232b81ce9f8a9b0a585eeff315ed6 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_13.txt @@ -0,0 +1,60 @@ +0.13459576929391090569e-32 +24.558941542500461187 +0.48509566723824261626e-31 +-669.66044971689436801 +-0.24454123585384020859e-29 +6672.9984830133931554 +0.18687481194464005187e-28 +-30603.665616389872425 +-0.5762278175772426705e-28 +73188.403298778778129 +0.85368067300925938918e-28 +-94443.321705008449291 +-0.60270147469466762691e-28 +62325.409421254674884 +0.16234284366194031353e-28 +-16494.674411780599848 +0.15326158858563023363e-46 +9.3562563603543978083 +-0.36897212304824964462e-45 +-59.163896393362639749 +0.17425439970330368218e-44 +148.86093062644842385 +-0.32067211000221387429e-44 +-175.8128748785829444 +0.27911573894864588724e-44 +109.11129968595543035 +-0.12259030930610072562e-44 +-36.676883997875556573 +0.26218914255796237778e-45 +6.3184629031129413078 +-0.21666232642127535753e-46 +-0.43711341508217764519 +0.6435519383199838375e-47 +5.078135697588612878 +0.81260103885576212533e-45 +-30.732991813718681529 +-0.16019847467842701065e-43 +144.10974681280942417 +0.10746315446051181804e-42 +-459.66168882614256179 +-0.36344872304451237262e-42 +1021.520644704596761 +0.72520712536978486691e-42 +-1620.5625670887702504 +-0.92730639785365506188e-42 +1864.6764641657026581 +0.79584309735406509106e-42 +-1567.4930087714349494 +-0.46919010314752753297e-42 +960.9703090934222369 +0.19086334965401618657e-42 +-424.32616187164667827 +-0.52743967802069637614e-43 +131.27850925600366538 +0.94704493797478696798e-44 +-26.9812576626115819 +-0.99818156176375019347e-45 +3.3065138731556502914 +0.46939046619219983164e-46 +-0.18274294462753398785 diff --git 
a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_14.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_14.txt new file mode 100644 index 0000000000000000000000000000000000000000..bd64f92a50d2ef058f977bb2d86d4ff7adb9b5b3 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_14.txt @@ -0,0 +1,74 @@ +-0.33857228343349222771e-46 +24.905214319375455558 +0.76706429670786537935e-44 +-682.38305758243000329 +-0.13331852725885950186e-42 +6809.4284539059995199 +0.91946456800204317828e-42 +-31250.710001710598132 +-0.3025478830899497108e-41 +74765.938836375718651 +0.50242602757177002413e-41 +-96504.683847583939449 +-0.40593124032144374049e-41 +63697.792377824620844 +0.12667142782789733501e-41 +-16860.262134719013661 +-0.9279917569679915957e-45 +16.828551192601130179 +0.83240811468667181169e-43 +-339.81175049565943051 +-0.12775656662581185272e-41 +2790.6999879384769844 +0.77015283672913125341e-41 +-11351.415157379078 +-0.24115991880599060476e-40 +26623.001028374524701 +0.44880705621387405238e-40 +-39384.032866197584799 +-0.53482162297220285674e-40 +38788.423034806038289 +0.42572250279855945685e-40 +-26239.530384498866697 +-0.23114662426334729952e-40 +12365.620701653234636 +0.85857146353371829812e-41 +-4053.3646008999918361 +-0.21456494030125522953e-41 +906.04288095108740556 +0.34480336789999280563e-42 +-131.68764920828800121 +-0.32171705933660227767e-43 +11.2176079033623702 +0.13242560040344363915e-44 +-0.42493802046747120845 +0.67287496871653099859e-47 +5.3175549768939158259 +0.56819927580108680432e-45 +-35.437153153157791187 +-0.1351878131554547236e-43 +184.12244132914048722 +0.10553176628958931735e-42 +-655.38683014625320828 +-0.41426651887176019808e-42 +1638.7833542806082222 +0.96309736116631665994e-42 +-2953.8623704822645685 +-0.14455668840936067085e-41 +3908.064233624186516 +0.14726501386448502925e-41 +-3834.9673916513135076 +-0.1047282511696152942e-41 +2799.6065476651709549 
+0.5261087287862766673e-42 +-1512.862318866923807 +-0.1860839022225468362e-42 +596.16013934000983037 +0.45364411019946816092e-43 +-166.32173930295850321 +-0.72578228765531381604e-44 +31.098836973988486605 +0.68580052063448574557e-45 +-3.4934937450619072513 +-0.28984981120663795748e-46 +0.17814215695649592896 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_4.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_4.txt new file mode 100644 index 0000000000000000000000000000000000000000..35f74eb975ea0d5bfa2f14b71515292db8adce3f --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_4.txt @@ -0,0 +1,6 @@ +-0.13308328571661910614e-45 +1.7030952550152622128 +0.4282303400683935347e-45 +-0.7119179308972252429 +-0.95007671461412588685e-46 +0.10588838088858641527 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_5.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_5.txt new file mode 100644 index 0000000000000000000000000000000000000000..b48dbd5cdf0bad09c2ceeb781bfd2d5254abb7f7 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_5.txt @@ -0,0 +1,14 @@ +-0.6557526550356444622e-47 +3.8161318912761891715 +0.16383497014687546643e-45 +-9.5345775055636790739 +-0.56775355546317866411e-45 +13.127053865564633214 +0.69572170343273392104e-45 +-9.2492590908629339246 +-0.3754881050914470711e-45 +3.4159401718320714855 +0.91355542332979866396e-46 +-0.62965100676194727464 +-0.81802023536366356411e-47 +0.045667675284501014142 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_6.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_6.txt new file mode 100644 index 0000000000000000000000000000000000000000..be7505497ba5e4a7c101d730786bd7b5958c1891 --- /dev/null +++ 
b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_6.txt @@ -0,0 +1,12 @@ +0.93844775458124077663e-48 +3.5011230698579893093 +0.22777858120903902345e-48 +-2.9653846348507701015 +0.43185878462241557434e-46 +2.5008586578047593635 +-0.2840296669129957945e-45 +-2.835698539325111711 +0.30424586148848378787e-45 +1.6508795809820303949 +-0.81872372642115644225e-46 +-0.3394796746275417859 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_7.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_7.txt new file mode 100644 index 0000000000000000000000000000000000000000..e374a801651781a5ad4ff73faec0ace224dbe6f4 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_7.txt @@ -0,0 +1,16 @@ +0.36047157227556087246e-35 +7.3044516495825141119 +-0.50547170420272221469e-34 +-34.682587110865950932 +0.11656466540909508014e-33 +59.859651829882618102 +-0.65429849283953137725e-34 +-31.875522590646616717 +-0.94649140234426094697e-48 +2.4008565221759781728 +0.64174463272534238015e-47 +-2.631254542617839592 +-0.72533856467681473674e-47 +1.5491267477359321766 +0.20691646642181211267e-47 +-0.33117295650430441081 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_8.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_8.txt new file mode 100644 index 0000000000000000000000000000000000000000..a02fdeaa8a3afe3e4807336f7a6fc7cb9d7e1cf6 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_8.txt @@ -0,0 +1,24 @@ +-0.53048975658957883424e-47 +8.8313307202241685663 +0.21584189100655274133e-45 +-46.457503989551297435 +-0.65893788882613692879e-45 +83.028223472040820362 +0.44320580215223931131e-45 +-44.992847782807097824 +-0.33945728644711295237e-31 +3.9488188508326321993 +0.87774430828590316545e-30 +-12.910301099228299329 +-0.37335685270661501373e-29 +28.086536217465829971 
+0.55927380858844763427e-29 +-35.596914896513755454 +-0.33696337530707381012e-29 +26.515937088133732328 +0.5368136791487787625e-30 +-11.418488936844970829 +0.19108101768442763358e-30 +2.6255844388133481275 +-0.55068698294223090738e-31 +-0.24917229999864296671 diff --git a/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_9.txt b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_9.txt new file mode 100644 index 0000000000000000000000000000000000000000..5f9d660e351d356c9226ca3b3e6b2b927f35928f --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/coeffResult/coeff_9.txt @@ -0,0 +1,32 @@ +0.38516974123418350064e-43 +18.096628571880738457 +-0.45973041691637794167e-41 +-434.03870327488615935 +0.79629916037569032555e-40 +4154.9710354569631624 +-0.52897711039631686342e-39 +-18684.694361314924706 +0.16721955114891783141e-38 +44165.71778893298258 +-0.26977742479850633046e-38 +-56552.792898340197704 +0.21412459138356943247e-38 +37115.61227257818608 +-0.66172245592719822948e-39 +-9782.4193389278198711 +-0.10450107406385446313e-45 +3.7975332336085665816 +0.42284220981801636912e-44 +-11.77181577711924821 +-0.22557111393663989842e-43 +24.977108667834693744 +0.44246287510686210342e-43 +-31.523884160399347606 +-0.41355419441164577394e-43 +23.729486312672242226 +0.20006015878309407409e-43 +-10.433180019592321979 +-0.48604113271279609616e-44 +2.4674397626083860086 +0.47125621405204993786e-45 +-0.24213010024761749988 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_10.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_10.txt new file mode 100644 index 0000000000000000000000000000000000000000..4fd8abf006dd2653ee9a35d4be74acedfb1df62e --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_10.txt @@ -0,0 +1,3 @@ +7 +7 +13 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_11.txt 
b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_11.txt new file mode 100644 index 0000000000000000000000000000000000000000..54bbd27b921974172038afb2b08e9b99fc3606dc --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_11.txt @@ -0,0 +1,3 @@ +7 +7 +27 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_12.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_12.txt new file mode 100644 index 0000000000000000000000000000000000000000..5dffe35a3300b8209e5e231447c9dcc9ded3e6af --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_12.txt @@ -0,0 +1,3 @@ +7 +15 +27 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_13.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_13.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f379df03a1c3068230833a5785d4edad120ae54 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_13.txt @@ -0,0 +1,3 @@ +15 +15 +27 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_14.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_14.txt new file mode 100644 index 0000000000000000000000000000000000000000..5237cd51c64383682452535557fc8133503ea18e --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_14.txt @@ -0,0 +1,3 @@ +15 +27 +29 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_4.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_4.txt new file mode 100644 index 0000000000000000000000000000000000000000..7ed6ff82de6bcc2a78243fc9c54d3ef5ac14da69 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_4.txt @@ -0,0 +1 @@ +5 diff --git 
a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_5.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_5.txt new file mode 100644 index 0000000000000000000000000000000000000000..b1bd38b62a0800a4f6a80c34e21c5acffae52c7e --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_5.txt @@ -0,0 +1 @@ +13 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_6.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_6.txt new file mode 100644 index 0000000000000000000000000000000000000000..dde5d5d0173bff70bf273c1450c977fdfc17d982 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_6.txt @@ -0,0 +1,2 @@ +3 +7 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_7.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_7.txt new file mode 100644 index 0000000000000000000000000000000000000000..49019db807899bc5793047943ce0fbb1a09b2e14 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_7.txt @@ -0,0 +1,2 @@ +7 +7 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_8.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_8.txt new file mode 100644 index 0000000000000000000000000000000000000000..12b65161666836e4cdaf4dba95501c9eb184aa12 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_8.txt @@ -0,0 +1,2 @@ +7 +15 diff --git a/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_9.txt b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_9.txt new file mode 100644 index 0000000000000000000000000000000000000000..d2de525e8598cc86ffc221bd5689677d5fc94b7f --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/degreeResult/deg_9.txt @@ -0,0 +1,2 @@ 
+15 +15 diff --git a/examples/community/homomorphic_inference/mindspore_poly/train_resnet20_cifar10.py b/examples/community/homomorphic_inference/mindspore_poly/train_resnet20_cifar10.py new file mode 100644 index 0000000000000000000000000000000000000000..122c6c2ea7e9cd4657689fcf6b3d6a8c7f0e522c --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_poly/train_resnet20_cifar10.py @@ -0,0 +1,202 @@ +from __future__ import print_function +from tqdm import * + +import sys +import argparse +import mindspore +import mindspore.nn as nn + + +from model.resnet_cifar10 import * +from model.utils_approx import rangeException + +import mindspore as ms +import mindspore.dataset as ds +import mindspore.dataset.vision as vision +import mindspore.dataset.transforms as transforms +from mindspore import dtype as mstype +parser = argparse.ArgumentParser(description='Implementation of of Section V-A for `Precise Approximation of Convolutional Neural' + + 'Networks for Homomorphically Encrypted Data.`') +parser.add_argument('--approx_method', default='proposed', dest='approx_method', type=str, + help='Method of approximating non-arithmetic operations. `proposed`: proposed composition of minimax polynomials, '\ + '`square`: approximate ReLU as x^2, `relu_aq`: approximate ReLU as 2^-3*x^2+2^-1*x+2^-2. '\ + 'For `square` and `relu_aq`, we use exact max-pooling function.') + +parser.add_argument('--alpha', default=14, dest='alpha', type=int, + help='The precision parameter. 
Integers from 4 to 14 can be used.') +parser.add_argument('--B_relu', default=50.0, dest='B_relu', type=float, + help='The bound of approximation range for the approximate ReLU function.') +parser.add_argument('--B_max', default=50.0, dest='B_max', type=float, + help='The bound of approximation range for the approximate max-pooling function.') +args = parser.parse_args() + +data_dir = "../datasets-cifar10-bin/cifar-10-batches-bin" # 数据集根目录 +batch_size = 256 # 批量大小 +image_size = 32 # 训练图像空间大小 +workers = 4 # 并行线程个数 +num_classes = 10 # 分类数量 + + +def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers): + + data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir, + usage=usage, + num_parallel_workers=workers, + shuffle=True) + + trans = [] + if usage == "train": + trans += [ + vision.RandomCrop((32, 32), (4, 4, 4, 4)), + vision.RandomHorizontalFlip(prob=0.5) + ] + + trans += [ + vision.Resize(resize), + vision.Rescale(1.0 / 255.0, 0.0), + vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), + vision.HWC2CHW() + ] + + target_trans = transforms.TypeCast(mstype.int32) + + # 数据映射操作 + data_set = data_set.map(operations=trans, + input_columns='image', + num_parallel_workers=workers) + + data_set = data_set.map(operations=target_trans, + input_columns='label', + num_parallel_workers=workers) + + # 批量操作 + data_set = data_set.batch(batch_size) + + return data_set + + +# 获取处理后的训练与测试数据集 + +dataset_train = create_dataset_cifar10(dataset_dir=data_dir, + usage="train", + resize=image_size, + batch_size=batch_size, + workers=workers) +step_size_train = dataset_train.get_dataset_size() + +dataset_val = create_dataset_cifar10(dataset_dir=data_dir, + usage="test", + resize=image_size, + batch_size=batch_size, + workers=workers) +step_size_val = dataset_val.get_dataset_size() + +approx_dict_list = [{'alpha': args.alpha, 'B': args.B_relu, 'type': args.approx_method}, + {'alpha': args.alpha, 'B': args.B_max, 'type': args.approx_method}] +# 定义ResNet20网络 
+network = resnet20(pretrained=False,approx_param_dict_list=approx_dict_list) +print(network) +# 全连接层输入层的大小 +in_channel = network.fc.in_channels +fc = nn.Dense(in_channels=in_channel, out_channels=10) +# 重置全连接层 +network.fc = fc + +# 设置学习率 +num_epochs = 50 +lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs, + step_per_epoch=step_size_train, decay_epoch=num_epochs) +# 定义优化器和损失函数 +opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9) +loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + + +def forward_fn(inputs, targets): + logits = network(inputs) + loss = loss_fn(logits, targets) + return loss + + +grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters) + + +def train_step(inputs, targets): + # print(inputs.shape) + # print(targets.shape) + loss, grads = grad_fn(inputs, targets) + opt(grads) + return loss + +import os + +# 创建迭代器 +data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs) +data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs) + +# 最佳模型存储路径 +best_acc = 0 +best_ckpt_dir = "../BestCheckpoint" +best_ckpt_path = "../BestCheckpoint/resnet20-best-poly.ckpt" + +if not os.path.exists(best_ckpt_dir): + os.mkdir(best_ckpt_dir) + +import mindspore.ops as ops + + +def train(data_loader, epoch): + """模型训练""" + losses = [] + network.set_train(True) + + for i, (images, labels) in enumerate(data_loader): + # print(images.shape) + # print(labels.shape) + loss = train_step(images, labels) + if i % 100 == 0 or i == step_size_train - 1: + print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' % + (epoch + 1, num_epochs, i + 1, step_size_train, loss)) + losses.append(loss) + + return sum(losses) / len(losses) + + +def evaluate(data_loader): + """模型验证""" + network.set_train(False) + + correct_num = 0.0 # 预测正确个数 + total_num = 0.0 # 预测总数 + + for images, labels in data_loader: + logits = network(images) + pred = 
logits.argmax(axis=1) # 预测结果 + correct = ops.equal(pred, labels).reshape((-1, )) + correct_num += correct.sum().asnumpy() + total_num += correct.shape[0] + + acc = correct_num / total_num # 准确率 + + return acc + +# 开始循环训练 +print("Start Training Loop ...") + +for epoch in range(num_epochs): + curr_loss = train(data_loader_train, epoch) + curr_acc = evaluate(data_loader_val) + + print("-" * 50) + print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % ( + epoch+1, num_epochs, curr_loss, curr_acc + )) + print("-" * 50) + + # 保存当前预测准确率最高的模型 + if curr_acc > best_acc: + best_acc = curr_acc + ms.save_checkpoint(network, best_ckpt_path) + +print("=" * 80) +print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, " + f"save the best ckpt file in {best_ckpt_path}", flush=True) \ No newline at end of file diff --git a/examples/community/homomorphic_inference/mindspore_quick_start.ipynb b/examples/community/homomorphic_inference/mindspore_quick_start.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..0aa775aaa411f1a6af04460ab54d09b8ad059c7b --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_quick_start.ipynb @@ -0,0 +1,589 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/beginner/mindspore_quick_start.ipynb) [![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/beginner/mindspore_quick_start.py) 
[![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/beginner/quick_start.ipynb)\n", + "\n", + "[基本介绍](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/introduction.html) || **快速入门** || [张量 Tensor](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/tensor.html) || [数据加载与处理](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/dataset.html) || [网络构建](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/model.html) || [函数式自动微分](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/autograd.html) || [模型训练](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/train.html) || [保存与加载](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html) || [使用静态图加速](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/accelerate_with_static_graph.html) || [自动混合精度](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/mixed_precision.html) ||" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 快速入门\n", + "\n", + "本节通过MindSpore的API来快速实现一个简单的深度学习模型。若想要深入了解MindSpore的使用方法,请参阅各节最后提供的参考链接。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import mindspore\n", + "from mindspore import nn\n", + "from mindspore.dataset import vision, transforms\n", + "from mindspore.dataset import MnistDataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 处理数据集\n", + "\n", + "MindSpore提供基于Pipeline的[数据引擎](https://www.mindspore.cn/docs/zh-CN/master/design/data_engine.html),通过[数据集(Dataset)](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/dataset.html)实现高效的数据预处理。在本教程中,我们使用Mnist数据集,自动下载完成后,使用`mindspore.dataset`提供的数据变换进行预处理。\n", + "\n", + "> 本章节中的示例代码依赖`download`,可使用命令`pip install download`安装。如本文档以Notebook运行时,完成安装后需要重启kernel才能执行后续代码。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip (10.3 MB)\n", + "\n", + "file_sizes: 100%|██████████████████████████| 10.8M/10.8M [00:01<00:00, 6.73MB/s]\n", + "Extracting zip file...\n", + "Successfully downloaded / unzipped to ./\n" + ] + } + ], + "source": [ + "# Download data from open datasets\n", + "from download import download\n", + "\n", + "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/\" \\\n", + " \"notebook/datasets/MNIST_Data.zip\"\n", + "path = download(url, \"./\", kind=\"zip\", replace=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "MNIST数据集目录结构如下:\n", + "\n", + "```text\n", + "MNIST_Data\n", + "└── train\n", + " ├── train-images-idx3-ubyte (60000个训练图片)\n", + " ├── train-labels-idx1-ubyte (60000个训练标签)\n", + "└── test\n", + " ├── t10k-images-idx3-ubyte (10000个测试图片)\n", + " ├── t10k-labels-idx1-ubyte (10000个测试标签)\n", + "\n", + "```\n", + "\n", + "数据下载完成后,获得数据集对象。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = MnistDataset('MNIST_Data/train')\n", + "test_dataset = MnistDataset('MNIST_Data/test')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "打印数据集中包含的数据列名,用于dataset的预处理。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['image', 'label']\n" + ] + } + ], + "source": [ + "print(train_dataset.get_col_names())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "MindSpore的dataset使用数据处理流水线(Data Processing Pipeline),需指定map、batch、shuffle等操作。这里我们使用map对图像数据及标签进行变换处理,将输入的图像缩放为1/255,根据均值0.1307和标准差值0.3081进行归一化处理,然后将处理好的数据集打包为大小为64的batch。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": 
{}, + "outputs": [], + "source": [ + "def datapipe(dataset, batch_size):\n", + " image_transforms = [\n", + " vision.Rescale(1.0 / 255.0, 0),\n", + " vision.Normalize(mean=(0.1307,), std=(0.3081,)),\n", + " vision.HWC2CHW()\n", + " ]\n", + " label_transform = transforms.TypeCast(mindspore.int32)\n", + "\n", + " dataset = dataset.map(image_transforms, 'image')\n", + " dataset = dataset.map(label_transform, 'label')\n", + " dataset = dataset.batch(batch_size)\n", + " return dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Map vision transforms and batch dataset\n", + "train_dataset = datapipe(train_dataset, 64)\n", + "test_dataset = datapipe(test_dataset, 64)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可使用[create_tuple_iterator](https://www.mindspore.cn/docs/zh-CN/master/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_tuple_iterator.html) 或[create_dict_iterator](https://www.mindspore.cn/docs/zh-CN/master/api_python/dataset/dataset_method/iterator/mindspore.dataset.Dataset.create_dict_iterator.html)对数据集进行迭代访问,查看数据和标签的shape和datatype。" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32\n", + "Shape of label: (64,) Int32\n" + ] + } + ], + "source": [ + "for image, label in test_dataset.create_tuple_iterator():\n", + " print(f\"Shape of image [N, C, H, W]: {image.shape} {image.dtype}\")\n", + " print(f\"Shape of label: {label.shape} {label.dtype}\")\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Shape of image [N, C, H, W]: (64, 1, 28, 28) Float32\n", + "Shape of label: (64,) Int32\n" + ] + } + ], + "source": [ + "for data in 
test_dataset.create_dict_iterator():\n", + " print(f\"Shape of image [N, C, H, W]: {data['image'].shape} {data['image'].dtype}\")\n", + " print(f\"Shape of label: {data['label'].shape} {data['label'].dtype}\")\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "更多细节详见[数据加载与处理](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/dataset.html)。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 网络构建\n", + "\n", + "`mindspore.nn`类是构建所有网络的基类,也是网络的基本单元。当用户需要自定义网络时,可以继承`nn.Cell`类,并重写`__init__`方法和`construct`方法。`__init__`包含所有网络层的定义,`construct`中包含数据([Tensor](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/tensor.html))的变换过程。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Network<\n", + " (flatten): Flatten<>\n", + " (dense_relu_sequential): SequentialCell<\n", + " (0): Dense\n", + " (1): ReLU<>\n", + " (2): Dense\n", + " (3): ReLU<>\n", + " (4): Dense\n", + " >\n", + " >\n" + ] + } + ], + "source": [ + "# Define model\n", + "class Network(nn.Cell):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.flatten = nn.Flatten()\n", + " self.dense_relu_sequential = nn.SequentialCell(\n", + " nn.Dense(28*28, 512),\n", + " nn.ReLU(),\n", + " nn.Dense(512, 512),\n", + " nn.ReLU(),\n", + " nn.Dense(512, 10)\n", + " )\n", + "\n", + " def construct(self, x):\n", + " x = self.flatten(x)\n", + " logits = self.dense_relu_sequential(x)\n", + " return logits\n", + "\n", + "model = Network()\n", + "print(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "更多细节详见[网络构建](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/model.html)。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 模型训练" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在模型训练中,一个完整的训练过程(step)需要实现以下三步:\n", + "\n", + "1. 
**正向计算**:模型预测结果(logits),并与正确标签(label)求预测损失(loss)。\n", + "2. **反向传播**:利用自动微分机制,自动求模型参数(parameters)对于loss的梯度(gradients)。\n", + "3. **参数优化**:将梯度更新到参数上。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "MindSpore使用函数式自动微分机制,因此针对上述步骤需要实现:\n", + "\n", + "1. 定义正向计算函数。\n", + "2. 使用[value_and_grad](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore/mindspore.value_and_grad.html)通过函数变换获得梯度计算函数。\n", + "3. 定义训练函数,使用[set_train](https://www.mindspore.cn/docs/zh-CN/master/api_python/nn/mindspore.nn.Cell.html#mindspore.nn.Cell.set_train)设置为训练模式,执行正向计算、反向传播和参数优化。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Instantiate loss function and optimizer\n", + "loss_fn = nn.CrossEntropyLoss()\n", + "optimizer = nn.SGD(model.trainable_params(), 1e-2)\n", + "\n", + "# 1. Define forward function\n", + "def forward_fn(data, label):\n", + " logits = model(data)\n", + " loss = loss_fn(logits, label)\n", + " return loss, logits\n", + "\n", + "# 2. Get gradient function\n", + "grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)\n", + "\n", + "# 3. 
Define function of one-step training\n", + "def train_step(data, label):\n", + " (loss, _), grads = grad_fn(data, label)\n", + " optimizer(grads)\n", + " return loss\n", + "\n", + "def train(model, dataset):\n", + " size = dataset.get_dataset_size()\n", + " model.set_train()\n", + " for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):\n", + " loss = train_step(data, label)\n", + "\n", + " if batch % 100 == 0:\n", + " loss, current = loss.asnumpy(), batch\n", + " print(f\"loss: {loss:>7f} [{current:>3d}/{size:>3d}]\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "除训练外,我们定义测试函数,用来评估模型的性能。" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def test(model, dataset, loss_fn):\n", + " num_batches = dataset.get_dataset_size()\n", + " model.set_train(False)\n", + " total, test_loss, correct = 0, 0, 0\n", + " for data, label in dataset.create_tuple_iterator():\n", + " pred = model(data)\n", + " total += len(data)\n", + " test_loss += loss_fn(pred, label).asnumpy()\n", + " correct += (pred.argmax(1) == label).asnumpy().sum()\n", + " test_loss /= num_batches\n", + " correct /= total\n", + " print(f\"Test: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "训练过程需多次迭代数据集,一次完整的迭代称为一轮(epoch)。在每一轮,遍历训练集进行训练,结束后使用测试集进行预测。打印每一轮的loss值和预测准确率(Accuracy),可以看到loss在不断下降,Accuracy在不断提高。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1\n", + "-------------------------------\n", + "loss: 2.302088 [ 0/938]\n", + "loss: 2.290692 [100/938]\n", + "loss: 2.266338 [200/938]\n", + "loss: 2.205240 [300/938]\n", + "loss: 1.907198 [400/938]\n", + "loss: 1.455603 [500/938]\n", + "loss: 0.861103 [600/938]\n", + "loss: 0.767219 [700/938]\n", + "loss: 0.422253 [800/938]\n", + "loss: 
0.513922 [900/938]\n", + "Test: \n", + " Accuracy: 83.8%, Avg loss: 0.529534 \n", + "\n", + "Epoch 2\n", + "-------------------------------\n", + "loss: 0.580867 [ 0/938]\n", + "loss: 0.479347 [100/938]\n", + "loss: 0.677991 [200/938]\n", + "loss: 0.550141 [300/938]\n", + "loss: 0.226565 [400/938]\n", + "loss: 0.314738 [500/938]\n", + "loss: 0.298739 [600/938]\n", + "loss: 0.459540 [700/938]\n", + "loss: 0.332978 [800/938]\n", + "loss: 0.406709 [900/938]\n", + "Test: \n", + " Accuracy: 90.2%, Avg loss: 0.334828 \n", + "\n", + "Epoch 3\n", + "-------------------------------\n", + "loss: 0.461890 [ 0/938]\n", + "loss: 0.242303 [100/938]\n", + "loss: 0.281414 [200/938]\n", + "loss: 0.207835 [300/938]\n", + "loss: 0.206000 [400/938]\n", + "loss: 0.409646 [500/938]\n", + "loss: 0.193608 [600/938]\n", + "loss: 0.217575 [700/938]\n", + "loss: 0.212817 [800/938]\n", + "loss: 0.202862 [900/938]\n", + "Test: \n", + " Accuracy: 91.9%, Avg loss: 0.280962 \n", + "\n", + "Done!\n" + ] + } + ], + "source": [ + "epochs = 3\n", + "for t in range(epochs):\n", + " print(f\"Epoch {t+1}\\n-------------------------------\")\n", + " train(model, train_dataset)\n", + " test(model, test_dataset, loss_fn)\n", + "print(\"Done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "更多细节详见[模型训练](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/train.html)。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 保存模型\n", + "\n", + "模型训练完成后,需要将其参数进行保存。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved Model to model.ckpt\n" + ] + } + ], + "source": [ + "# Save checkpoint\n", + "mindspore.save_checkpoint(model, \"model.ckpt\")\n", + "print(\"Saved Model to model.ckpt\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 加载模型" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + 
"加载保存的权重分为两步:\n", + "\n", + "1. 重新实例化模型对象,构造模型。\n", + "2. 加载模型参数,并将其加载至模型上。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[]\n" + ] + } + ], + "source": [ + "# Instantiate a random initialized model\n", + "model = Network()\n", + "# Load checkpoint and load parameter to model\n", + "param_dict = mindspore.load_checkpoint(\"model.ckpt\")\n", + "param_not_load, _ = mindspore.load_param_into_net(model, param_dict)\n", + "print(param_not_load)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "> `param_not_load`是未被加载的参数列表,为空时代表所有参数均加载成功。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "加载后的模型可以直接用于预测推理。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Predicted: \"[3 9 6 1 6 7 4 5 2 2]\", Actual: \"[3 9 6 1 6 7 4 5 2 2]\"\n" + ] + } + ], + "source": [ + "model.set_train(False)\n", + "for data, label in test_dataset:\n", + " pred = model(data)\n", + " predicted = pred.argmax(1)\n", + " print(f'Predicted: \"{predicted[:10]}\", Actual: \"{label[:10]}\"')\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "更多细节详见[保存与加载](https://www.mindspore.cn/tutorials/zh-CN/master/beginner/save_load.html)。" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "MindSpore", + "language": "python", + "name": "mindspore" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.5 (default, Oct 25 2019, 15:51:11) \n[GCC 7.3.0]" + }, + "vscode": { + "interpreter": { + "hash": "8c9da313289c39257cb28b126d2dadd33153d4da4d524f730c81a4aaccbd2ca7" + } + } + }, + "nbformat": 4, + 
"nbformat_minor": 4 +} diff --git a/examples/community/homomorphic_inference/mindspore_resnet20.ipynb b/examples/community/homomorphic_inference/mindspore_resnet20.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..cf0f13d535db0d4f5874504214aa2a62c00ff8c4 --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_resnet20.ipynb @@ -0,0 +1,726 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a987ee48", + "metadata": {}, + "source": [ + "## 数据集准备与加载\n", + "\n", + "[CIFAR-10数据集](http://www.cs.toronto.edu/~kriz/cifar.html)共有60000张32*32的彩色图像,分为10个类别,每类有6000张图,数据集一共有50000张训练图片和10000张评估图片。首先,如下示例使用`download`接口下载并解压,目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f9b81fb", + "metadata": {}, + "outputs": [], + "source": [ + "# from download import download\n", + "\n", + "# url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz\"\n", + "\n", + "# download(url, \"./datasets-cifar10-bin\", kind=\"tar.gz\", replace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "7e9020ba", + "metadata": {}, + "source": [ + "下载后的数据集目录结构如下:\n", + "\n", + "```text\n", + "datasets-cifar10-bin/cifar-10-batches-bin\n", + "├── batches.meta.txt\n", + "├── data_batch_1.bin\n", + "├── data_batch_2.bin\n", + "├── data_batch_3.bin\n", + "├── data_batch_4.bin\n", + "├── data_batch_5.bin\n", + "├── readme.html\n", + "└── test_batch.bin\n", + "\n", + "```\n", + "\n", + "然后,使用`mindspore.dataset.Cifar10Dataset`接口来加载数据集,并进行相关图像增强操作。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "df7fb621", + "metadata": {}, + "outputs": [], + "source": [ + "import mindspore as ms\n", + "import mindspore.dataset as ds\n", + "import mindspore.dataset.vision as vision\n", + "import mindspore.dataset.transforms as transforms\n", + "from mindspore import dtype as mstype\n", + "\n", + "data_dir = 
\"./datasets-cifar10-bin/cifar-10-batches-bin\" # 数据集根目录\n", + "batch_size = 256 # 批量大小\n", + "image_size = 32 # 训练图像空间大小\n", + "workers = 4 # 并行线程个数\n", + "num_classes = 10 # 分类数量\n", + "\n", + "\n", + "def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers):\n", + "\n", + " data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir,\n", + " usage=usage,\n", + " num_parallel_workers=workers,\n", + " shuffle=True)\n", + "\n", + " trans = []\n", + " if usage == \"train\":\n", + " trans += [\n", + " vision.RandomCrop((32, 32), (4, 4, 4, 4)),\n", + " vision.RandomHorizontalFlip(prob=0.5)\n", + " ]\n", + "\n", + " trans += [\n", + " vision.Resize(resize),\n", + " vision.Rescale(1.0 / 255.0, 0.0),\n", + " vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n", + " vision.HWC2CHW()\n", + " ]\n", + "\n", + " target_trans = transforms.TypeCast(mstype.int32)\n", + "\n", + " # 数据映射操作\n", + " data_set = data_set.map(operations=trans,\n", + " input_columns='image',\n", + " num_parallel_workers=workers)\n", + "\n", + " data_set = data_set.map(operations=target_trans,\n", + " input_columns='label',\n", + " num_parallel_workers=workers)\n", + "\n", + " # 批量操作\n", + " data_set = data_set.batch(batch_size)\n", + "\n", + " return data_set\n", + "\n", + "\n", + "# 获取处理后的训练与测试数据集\n", + "\n", + "dataset_train = create_dataset_cifar10(dataset_dir=data_dir,\n", + " usage=\"train\",\n", + " resize=image_size,\n", + " batch_size=batch_size,\n", + " workers=workers)\n", + "step_size_train = dataset_train.get_dataset_size()\n", + "\n", + "dataset_val = create_dataset_cifar10(dataset_dir=data_dir,\n", + " usage=\"test\",\n", + " resize=image_size,\n", + " batch_size=batch_size,\n", + " workers=workers)\n", + "step_size_val = dataset_val.get_dataset_size()" + ] + }, + { + "cell_type": "markdown", + "id": "21e86f95", + "metadata": {}, + "source": [ + "对CIFAR-10训练数据集进行可视化。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c3ffabb3", + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape: (256, 3, 32, 32), Label shape: (256,)\n", + "Labels: [4 6 4 1 8 3]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "data_iter = next(dataset_train.create_dict_iterator())\n", + "\n", + "images = data_iter[\"image\"].asnumpy()\n", + "labels = data_iter[\"label\"].asnumpy()\n", + "print(f\"Image shape: {images.shape}, Label shape: {labels.shape}\")\n", + "\n", + "# 训练数据集中,前六张图片所对应的标签\n", + "print(f\"Labels: {labels[:6]}\")\n", + "\n", + "classes = []\n", + "\n", + "with open(data_dir + \"/batches.meta.txt\", \"r\") as f:\n", + " for line in f:\n", + " line = line.rstrip()\n", + " if line:\n", + " classes.append(line)\n", + "\n", + "# 训练数据集的前六张图片\n", + "plt.figure()\n", + "for i in range(6):\n", + " plt.subplot(2, 3, i + 1)\n", + " image_trans = np.transpose(images[i], (1, 2, 0))\n", + " mean = np.array([0.4914, 0.4822, 0.4465])\n", + " std = np.array([0.2023, 0.1994, 0.2010])\n", + " image_trans = std * image_trans + mean\n", + " image_trans = np.clip(image_trans, 0, 1)\n", + " plt.title(f\"{classes[labels[i]]}\")\n", + " plt.imshow(image_trans)\n", + " plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1ebef3d0", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Type, Union, List, Optional\n", + "import mindspore.nn as nn\n", + "from mindspore.common.initializer import Normal\n", + "from mindspore import load_checkpoint, load_param_into_net\n", + "\n", + "# 初始化卷积层与BatchNorm的参数\n", + "weight_init = Normal(mean=0, sigma=0.02)\n", + "gamma_init = Normal(mean=1, sigma=0.02)\n", + "\n", + "class ResidualBlockBase(nn.Cell):\n", + " expansion: int = 1 # 最后一个卷积核数量与第一个卷积核数量相等\n", + "\n", + " def __init__(self, in_channel: int, out_channel: int,\n", + " stride: int = 1, \n", + " down_sample: Optional[nn.Cell] = None) -> None:\n", + " super(ResidualBlockBase, self).__init__()\n", + " # if not norm:\n", + " # self.norm = nn.BatchNorm2d(out_channel)\n", + " # 
else:\n", + " # self.norm = norm \n", + "\n", + " \n", + " self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init)\n", + " self.bn1 = nn.BatchNorm2d(out_channel)\n", + " self.relu = nn.ReLU()\n", + " self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, weight_init=weight_init)\n", + " self.bn2 = nn.BatchNorm2d(out_channel)\n", + " # self.down_sample = down_sample\n", + " self.down_sample = nn.SequentialCell()\n", + " if stride != 1 or in_channel != out_channel:\n", + " self.down_sample = nn.SequentialCell(\n", + " nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride),\n", + " nn.BatchNorm2d(out_channel)\n", + " )\n", + "\n", + " def construct(self, x):\n", + " \"\"\"ResidualBlockBase construct.\"\"\"\n", + " identity = x # shortcuts分支\n", + "\n", + " out = self.conv1(x) # 主分支第一层:3*3卷积层\n", + " out = self.bn1(out)\n", + " out = self.relu(out)\n", + " out = self.conv2(out) # 主分支第二层:3*3卷积层\n", + " out = self.bn2(out)\n", + "\n", + " if self.down_sample is not None:\n", + " identity = self.down_sample(x)\n", + " out += identity # 输出为主分支与shortcuts之和\n", + " out = self.relu(out)\n", + "\n", + " return out\n", + "\n", + "\n", + "# def make_layer(last_out_channel, block: Type[ResidualBlockBase],\n", + "# channel: int, block_nums: int, stride: int = 1):\n", + "# # down_sample = None # shortcuts分支\n", + "\n", + "# # if stride != 1 or last_out_channel != channel * block.expansion:\n", + "\n", + "# # down_sample = nn.SequentialCell([\n", + "# # nn.Conv2d(last_out_channel, channel * block.expansion,\n", + "# # kernel_size=1, stride=stride, weight_init=weight_init),\n", + "# # nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)\n", + "# # ])\n", + "\n", + "# # layers = []\n", + "# # layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))\n", + "\n", + "# # in_channel = channel * block.expansion\n", + "# # # 堆叠残差网络\n", + "# # for _ in range(1, block_nums):\n", + "\n", + 
"# # layers.append(block(in_channel, channel))\n", + "# strides = [stride] + [1] * (block_nums - 1)\n", + "# layers = []\n", + "# for stride in strides:\n", + "# layers.append(block(last_out_channel, channel, stride))\n", + "# last_out_channel = channel * block.expansion\n", + "# return nn.SequentialCell(layers)\n", + "\n", + "class ResNet(nn.Cell):\n", + " def __init__(self, block: Type[ResidualBlockBase],\n", + " layer_nums: List[int], num_classes: int, input_channel: int) -> None:\n", + " super(ResNet, self).__init__()\n", + " self.in_channels = 16\n", + "\n", + " # 第一个卷积层,输入channel为3(彩色图像),输出channel为16\n", + " self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, weight_init=weight_init,pad_mode='pad',padding=1)\n", + " self.norm = nn.BatchNorm2d(16)\n", + " self.relu = nn.ReLU()\n", + " # # 最大池化层,缩小图片的尺寸\n", + " # self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')\n", + " # 各个残差网络结构块定义\n", + " self.layer1 = self.make_layer(block, 16, layer_nums[0],stride=1)\n", + " self.layer2 = self.make_layer(block, 32, layer_nums[1], stride=2)\n", + " self.layer3 = self.make_layer(block, 64, layer_nums[2], stride=2)\n", + " # self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)\n", + " # 平均池化层\n", + " self.avg_pool = nn.AvgPool2d(8,1)\n", + " # self.avg_pool = nn.AvgPool2d(pad_mode='pad',padding=(1,1))\n", + " # flattern层\n", + " self.flatten = nn.Flatten()\n", + " # 全连接层\n", + " self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)\n", + "\n", + " # def make_layer(self, block, out_channels, num_blocks, stride):\n", + " # strides = [stride] + [1] * (num_blocks - 1)\n", + " # layers = []\n", + " # for stride in strides:\n", + " # layers.append(block(self.in_channels, out_channels, stride))\n", + " # self.in_channels = out_channels * block.expansion\n", + " # return nn.SequentialCell(layers)\n", + " def make_layer(self, block, out_channels, num_blocks, stride):\n", + " layers = []\n", + " 
layers.append(block(self.in_channels, out_channels, stride))\n", + " self.in_channels = out_channels\n", + " for _ in range(1, num_blocks):\n", + " layers.append(block(out_channels, out_channels))\n", + " return nn.SequentialCell(*layers)\n", + " def construct(self, x):\n", + " # print(x.shape)\n", + " x = self.conv1(x)\n", + " # print(x.shape)\n", + " x = self.norm(x)\n", + " # print(x.shape)\n", + " x = self.relu(x)\n", + " # print(x.shape)\n", + " x = self.layer1(x)\n", + " # print(x.shape)\n", + " x = self.layer2(x)\n", + " # print(x.shape)\n", + " x = self.layer3(x)\n", + " # print(x.shape)\n", + " x = self.avg_pool(x)\n", + " # print(x.shape)\n", + " x = self.flatten(x)\n", + " # print(x.shape)\n", + " x = self.fc(x)\n", + "# (256, 3, 32, 32)\n", + "# (256, 16, 32, 32)\n", + "# (256, 16, 32, 32)\n", + "# (256, 16, 32, 32)\n", + "# (256, 16, 32, 32)\n", + "# (256, 32, 16, 16)\n", + "# (256, 64, 8, 8)\n", + "# (256, 64, 1, 1)\n", + "# (256, 64)\n", + " return x\n", + "\n", + "def _resnet(model_url: str, block: Type[ResidualBlockBase],\n", + " layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,\n", + " input_channel: int):\n", + " model = ResNet(block, layers, num_classes, input_channel)\n", + "\n", + " if pretrained:\n", + " # 加载预训练模型\n", + " download(url=model_url, path=pretrained_ckpt, replace=True)\n", + " param_dict = load_checkpoint(pretrained_ckpt)\n", + " load_param_into_net(model, param_dict)\n", + "\n", + " return model\n", + "\n", + "def resnet20(num_classes: int = 10, pretrained: bool = False):\n", + " \"\"\"ResNet20模型\"\"\"\n", + " resnet20_url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt\"\n", + " resnet20_ckpt = \"./LoadPretrainedModel/resnet20_new.ckpt\"\n", + " return _resnet(resnet20_url, ResidualBlockBase, [3, 3, 3], num_classes,\n", + " pretrained, resnet20_ckpt, 64)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9cf10c03", + 
"metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ResNet<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (layer1): SequentialCell<\n", + " (0): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " (1): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " (2): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " >\n", + " (layer2): SequentialCell<\n", + " (0): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<\n", + " (0): Conv2d, bias_init=None, format=NCHW>\n", + " (1): BatchNorm2d\n", + " >\n", + " >\n", + " (1): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " (2): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " >\n", + " 
(layer3): SequentialCell<\n", + " (0): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<\n", + " (0): Conv2d, bias_init=None, format=NCHW>\n", + " (1): BatchNorm2d\n", + " >\n", + " >\n", + " (1): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " (2): ResidualBlockBase<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (bn1): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (bn2): BatchNorm2d\n", + " (down_sample): SequentialCell<>\n", + " >\n", + " >\n", + " (avg_pool): AvgPool2d\n", + " (flatten): Flatten<>\n", + " (fc): Dense\n", + " >\n" + ] + } + ], + "source": [ + "# 定义ResNet20网络\n", + "network = resnet20(pretrained=False)\n", + "print(network)\n", + "# 全连接层输入层的大小\n", + "in_channel = network.fc.in_channels\n", + "fc = nn.Dense(in_channels=in_channel, out_channels=10)\n", + "# 重置全连接层\n", + "network.fc = fc" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e1c632ff", + "metadata": {}, + "outputs": [], + "source": [ + "# 设置学习率\n", + "num_epochs = 50\n", + "lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,\n", + " step_per_epoch=step_size_train, decay_epoch=num_epochs)\n", + "# 定义优化器和损失函数\n", + "opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)\n", + "loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n", + "\n", + "\n", + "def forward_fn(inputs, targets):\n", + " logits = network(inputs)\n", + " loss = loss_fn(logits, targets)\n", + " return loss\n", + "\n", + "\n", + "grad_fn = 
ms.value_and_grad(forward_fn, None, opt.parameters)\n", + "\n", + "\n", + "def train_step(inputs, targets):\n", + " # print(inputs.shape)\n", + " # print(targets.shape)\n", + " loss, grads = grad_fn(inputs, targets)\n", + " opt(grads)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b627e30c", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# 创建迭代器\n", + "data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)\n", + "data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)\n", + "\n", + "# 最佳模型存储路径\n", + "best_acc = 0\n", + "best_ckpt_dir = \"./BestCheckpoint\"\n", + "best_ckpt_path = \"./BestCheckpoint/resnet20-best.ckpt\"\n", + "\n", + "if not os.path.exists(best_ckpt_dir):\n", + " os.mkdir(best_ckpt_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "8a5170df", + "metadata": {}, + "outputs": [], + "source": [ + "import mindspore.ops as ops\n", + "\n", + "\n", + "def train(data_loader, epoch):\n", + " \"\"\"模型训练\"\"\"\n", + " losses = []\n", + " network.set_train(True)\n", + "\n", + " for i, (images, labels) in enumerate(data_loader):\n", + " # print(images.shape)\n", + " # print(labels.shape)\n", + " loss = train_step(images, labels)\n", + " if i % 100 == 0 or i == step_size_train - 1:\n", + " print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %\n", + " (epoch + 1, num_epochs, i + 1, step_size_train, loss))\n", + " losses.append(loss)\n", + "\n", + " return sum(losses) / len(losses)\n", + "\n", + "\n", + "def evaluate(data_loader):\n", + " \"\"\"模型验证\"\"\"\n", + " network.set_train(False)\n", + "\n", + " correct_num = 0.0 # 预测正确个数\n", + " total_num = 0.0 # 预测总数\n", + "\n", + " for images, labels in data_loader:\n", + " logits = network(images)\n", + " pred = logits.argmax(axis=1) # 预测结果\n", + " correct = ops.equal(pred, labels).reshape((-1, ))\n", + " correct_num += correct.sum().asnumpy()\n", + " total_num += 
correct.shape[0]\n", + "\n", + " acc = correct_num / total_num # 准确率\n", + "\n", + " return acc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "562a04ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start Training Loop ...\n", + "Epoch: [ 1/ 50], Steps: [ 1/196], Train Loss: [2.400]\n" + ] + } + ], + "source": [ + "# 开始循环训练\n", + "print(\"Start Training Loop ...\")\n", + "\n", + "for epoch in range(num_epochs):\n", + " curr_loss = train(data_loader_train, epoch)\n", + " curr_acc = evaluate(data_loader_val)\n", + "\n", + " print(\"-\" * 50)\n", + " print(\"Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]\" % (\n", + " epoch+1, num_epochs, curr_loss, curr_acc\n", + " ))\n", + " print(\"-\" * 50)\n", + "\n", + " # 保存当前预测准确率最高的模型\n", + " if curr_acc > best_acc:\n", + " best_acc = curr_acc\n", + " ms.save_checkpoint(network, best_ckpt_path)\n", + "\n", + "print(\"=\" * 80)\n", + "print(f\"End of validation the best Accuracy is: {best_acc: 5.3f}, \"\n", + " f\"save the best ckpt file in {best_ckpt_path}\", flush=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "46e28f6f", + "metadata": {}, + "source": [ + "## 可视化模型预测\n", + "\n", + "定义`visualize_model`函数,使用上述验证精度最高的模型对CIFAR-10测试数据集进行预测,并将预测结果可视化。若预测字体颜色为蓝色表示为预测正确,预测字体颜色为红色则表示预测错误。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ba2fa94", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def visualize_model(best_ckpt_path, dataset_val):\n", + " num_class = 10\n", + " net = resnet20(num_class)\n", + " # 加载模型参数\n", + " param_dict = ms.load_checkpoint(best_ckpt_path)\n", + " ms.load_param_into_net(net, param_dict)\n", + " # 加载验证集的数据进行验证\n", + " data = next(dataset_val.create_dict_iterator())\n", + " images = data[\"image\"]\n", + " labels = data[\"label\"]\n", + " # 预测图像类别\n", + " output = net(data['image'])\n", + " pred = np.argmax(output.asnumpy(), axis=1)\n", + "\n", + " # 图像分类\n", + " classes = []\n", + "\n", + " with open(data_dir + \"/batches.meta.txt\", \"r\") as f:\n", + " for line in f:\n", + " line = line.rstrip()\n", + " if line:\n", + " classes.append(line)\n", + "\n", + " # 显示图像及图像的预测值\n", + " plt.figure()\n", + " for i in range(6):\n", + " plt.subplot(2, 3, i + 1)\n", + " # 若预测正确,显示为蓝色;若预测错误,显示为红色\n", + " color = 'blue' if pred[i] == labels.asnumpy()[i] else 'red'\n", + " plt.title('predict:{}'.format(classes[pred[i]]), color=color)\n", + " picture_show = np.transpose(images.asnumpy()[i], (1, 2, 0))\n", + " mean = np.array([0.4914, 0.4822, 0.4465])\n", + " std = np.array([0.2023, 0.1994, 0.2010])\n", + " picture_show = std * picture_show + mean\n", + " picture_show = np.clip(picture_show, 0, 1)\n", + " plt.imshow(picture_show)\n", + " plt.axis('off')\n", + "\n", + " plt.show()\n", + "\n", + "\n", + "# 使用测试数据集进行验证\n", + "visualize_model(best_ckpt_path=best_ckpt_path, dataset_val=dataset_val)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +} diff --git a/examples/community/homomorphic_inference/mindspore_resnet20.py b/examples/community/homomorphic_inference/mindspore_resnet20.py new file mode 100644 index 0000000000000000000000000000000000000000..03cc4ad9e356da25dce9b66a6ceb58648af2182c --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_resnet20.py @@ -0,0 +1,305 @@ +import mindspore as ms +import mindspore.dataset as ds +import mindspore.dataset.vision as vision +import mindspore.dataset.transforms as transforms +from mindspore import dtype as mstype + +data_dir = "./datasets-cifar10-bin/cifar-10-batches-bin" # 数据集根目录 +batch_size = 256 # 批量大小 +image_size = 32 # 训练图像空间大小 +workers = 4 # 并行线程个数 +num_classes = 10 # 分类数量 + + +def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers): + + data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir, + usage=usage, + num_parallel_workers=workers, + shuffle=True) + + trans = [] + if usage == "train": + trans += [ + vision.RandomCrop((32, 32), (4, 4, 4, 4)), + vision.RandomHorizontalFlip(prob=0.5) + ] + + trans += [ + vision.Resize(resize), + vision.Rescale(1.0 / 255.0, 0.0), + vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]), + vision.HWC2CHW() + ] + + target_trans = transforms.TypeCast(mstype.int32) + + # 数据映射操作 + data_set = data_set.map(operations=trans, + input_columns='image', + num_parallel_workers=workers) + + data_set = data_set.map(operations=target_trans, + input_columns='label', + num_parallel_workers=workers) + + # 批量操作 + data_set = data_set.batch(batch_size) + + return data_set + + +# 获取处理后的训练与测试数据集 + +dataset_train = create_dataset_cifar10(dataset_dir=data_dir, + usage="train", + resize=image_size, + batch_size=batch_size, + workers=workers) +step_size_train = dataset_train.get_dataset_size() + +dataset_val = create_dataset_cifar10(dataset_dir=data_dir, + usage="test", + resize=image_size, + batch_size=batch_size, + workers=workers) +step_size_val = 
dataset_val.get_dataset_size() + + + +from typing import Type, Union, List, Optional +import mindspore.nn as nn +from mindspore.common.initializer import Normal +from mindspore import load_checkpoint, load_param_into_net + +# 初始化卷积层与BatchNorm的参数 +weight_init = Normal(mean=0, sigma=0.02) +gamma_init = Normal(mean=1, sigma=0.02) + +class ResidualBlockBase(nn.Cell): + expansion: int = 1 # 最后一个卷积核数量与第一个卷积核数量相等 + + def __init__(self, in_channel: int, out_channel: int, + stride: int = 1, + down_sample: Optional[nn.Cell] = None) -> None: + super(ResidualBlockBase, self).__init__() + # if not norm: + # self.norm = nn.BatchNorm2d(out_channel) + # else: + # self.norm = norm + + + self.conv1 = nn.Conv2d(in_channel, out_channel,kernel_size=3, stride=stride,weight_init=weight_init) + self.bn1 = nn.BatchNorm2d(out_channel) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=3, weight_init=weight_init) + self.bn2 = nn.BatchNorm2d(out_channel) + # self.down_sample = down_sample + self.down_sample = nn.SequentialCell() + if stride != 1 or in_channel != out_channel: + self.down_sample = nn.SequentialCell( + nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride), + nn.BatchNorm2d(out_channel) + ) + + def construct(self, x): + """ResidualBlockBase construct.""" + identity = x # shortcuts分支 + + out = self.conv1(x) # 主分支第一层:3*3卷积层 + out = self.bn1(out) + out = self.relu(out) + out = self.conv2(out) # 主分支第二层:3*3卷积层 + out = self.bn2(out) + + if self.down_sample is not None: + identity = self.down_sample(x) + out += identity # 输出为主分支与shortcuts之和 + out = self.relu(out) + + return out + + + + +class ResNet(nn.Cell): + def __init__(self, block: Type[ResidualBlockBase], + layer_nums: List[int], num_classes: int, input_channel: int) -> None: + super(ResNet, self).__init__() + self.in_channels = 16 + + # 第一个卷积层,输入channel为3(彩色图像),输出channel为16 + self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, weight_init=weight_init,pad_mode='pad',padding=1) + 
self.norm = nn.BatchNorm2d(16) + self.relu = nn.ReLU() + # # 最大池化层,缩小图片的尺寸 + # self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + # 各个残差网络结构块定义 + self.layer1 = self.make_layer(block, 16, layer_nums[0],stride=1) + self.layer2 = self.make_layer(block, 32, layer_nums[1], stride=2) + self.layer3 = self.make_layer(block, 64, layer_nums[2], stride=2) + # self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2) + # 平均池化层 + self.avg_pool = nn.AvgPool2d(8,1) + # self.avg_pool = nn.AvgPool2d(pad_mode='pad',padding=(1,1)) + # flattern层 + self.flatten = nn.Flatten() + # 全连接层 + self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes) + + + def make_layer(self, block, out_channels, num_blocks, stride): + layers = [] + layers.append(block(self.in_channels, out_channels, stride)) + self.in_channels = out_channels + for _ in range(1, num_blocks): + layers.append(block(out_channels, out_channels)) + return nn.SequentialCell(*layers) + def construct(self, x): + # print(x.shape) + x = self.conv1(x) + # print(x.shape) + x = self.norm(x) + # print(x.shape) + x = self.relu(x) + # print(x.shape) + x = self.layer1(x) + # print(x.shape) + x = self.layer2(x) + # print(x.shape) + x = self.layer3(x) + # print(x.shape) + x = self.avg_pool(x) + # print(x.shape) + x = self.flatten(x) + # print(x.shape) + x = self.fc(x) + return x + +def _resnet(model_url: str, block: Type[ResidualBlockBase], + layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str, + input_channel: int): + model = ResNet(block, layers, num_classes, input_channel) + + if pretrained: + # 加载预训练模型 + download(url=model_url, path=pretrained_ckpt, replace=True) + param_dict = load_checkpoint(pretrained_ckpt) + load_param_into_net(model, param_dict) + + return model + +def resnet20(num_classes: int = 10, pretrained: bool = False): + """ResNet20模型""" + resnet20_url = 
"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt" + resnet20_ckpt = "./LoadPretrainedModel/resnet20_new.ckpt" + return _resnet(resnet20_url, ResidualBlockBase, [3, 3, 3], num_classes, + pretrained, resnet20_ckpt, 64) + +# 定义ResNet20网络 +network = resnet20(pretrained=False) +print(network) +# 全连接层输入层的大小 +in_channel = network.fc.in_channels +fc = nn.Dense(in_channels=in_channel, out_channels=10) +# 重置全连接层 +network.fc = fc + +# 设置学习率 +num_epochs = 50 +lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs, + step_per_epoch=step_size_train, decay_epoch=num_epochs) +# 定义优化器和损失函数 +opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9) +loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') + + +def forward_fn(inputs, targets): + logits = network(inputs) + loss = loss_fn(logits, targets) + return loss + + +grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters) + + +def train_step(inputs, targets): + # print(inputs.shape) + # print(targets.shape) + loss, grads = grad_fn(inputs, targets) + opt(grads) + return loss + +import os + +# 创建迭代器 +data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs) +data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs) + +# 最佳模型存储路径 +best_acc = 0 +best_ckpt_dir = "./BestCheckpoint" +best_ckpt_path = "./BestCheckpoint/resnet20-best.ckpt" + +if not os.path.exists(best_ckpt_dir): + os.mkdir(best_ckpt_dir) + +import mindspore.ops as ops + + +def train(data_loader, epoch): + """模型训练""" + losses = [] + network.set_train(True) + + for i, (images, labels) in enumerate(data_loader): + # print(images.shape) + # print(labels.shape) + loss = train_step(images, labels) + if i % 100 == 0 or i == step_size_train - 1: + print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' % + (epoch + 1, num_epochs, i + 1, step_size_train, loss)) + losses.append(loss) + + 
return sum(losses) / len(losses) + + +def evaluate(data_loader): + """模型验证""" + network.set_train(False) + + correct_num = 0.0 # 预测正确个数 + total_num = 0.0 # 预测总数 + + for images, labels in data_loader: + logits = network(images) + pred = logits.argmax(axis=1) # 预测结果 + correct = ops.equal(pred, labels).reshape((-1, )) + correct_num += correct.sum().asnumpy() + total_num += correct.shape[0] + + acc = correct_num / total_num # 准确率 + + return acc + +# 开始循环训练 +print("Start Training Loop ...") + +for epoch in range(num_epochs): + curr_loss = train(data_loader_train, epoch) + curr_acc = evaluate(data_loader_val) + + print("-" * 50) + print("Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]" % ( + epoch+1, num_epochs, curr_loss, curr_acc + )) + print("-" * 50) + + # 保存当前预测准确率最高的模型 + if curr_acc > best_acc: + best_acc = curr_acc + ms.save_checkpoint(network, best_ckpt_path) + +print("=" * 80) +print(f"End of validation the best Accuracy is: {best_acc: 5.3f}, " + f"save the best ckpt file in {best_ckpt_path}", flush=True) \ No newline at end of file diff --git a/examples/community/homomorphic_inference/mindspore_resnet50.ipynb b/examples/community/homomorphic_inference/mindspore_resnet50.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..12aed0fe8e29ea7cbd3e9551c0c87e103250f01a --- /dev/null +++ b/examples/community/homomorphic_inference/mindspore_resnet50.ipynb @@ -0,0 +1,1032 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fa7e3e52", + "metadata": {}, + "source": [ + "[![下载Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/cv/mindspore_resnet50.ipynb) 
[![下载样例代码](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/master/tutorials/zh_cn/cv/mindspore_resnet50.py) [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/master/tutorials/source_zh_cn/cv/resnet50.ipynb)\n", + "\n", + "# ResNet50图像分类\n", + "\n", + "图像分类是最基础的计算机视觉应用,属于有监督学习类别,如给定一张图像(猫、狗、飞机、汽车等等),判断图像所属的类别。本章将介绍使用ResNet50网络对CIFAR-10数据集进行分类。\n", + "\n", + "## ResNet网络介绍\n", + "\n", + "ResNet50网络是2015年由微软实验室的何恺明提出,获得ILSVRC2015图像分类竞赛第一名。在ResNet网络提出之前,传统的卷积神经网络都是将一系列的卷积层和池化层堆叠得到的,但当网络堆叠到一定深度时,就会出现退化问题。下图是在CIFAR-10数据集上使用56层网络与20层网络训练误差和测试误差图,由图中数据可以看出,56层网络比20层网络训练误差和测试误差更大,随着网络的加深,其误差并没有如预想的一样减小。\n", + "\n", + "![resnet-1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/resnet_1.png)\n", + "\n", + "ResNet网络提出了残差网络结构(Residual Network)来减轻退化问题,使用ResNet网络可以实现搭建较深的网络结构(突破1000层)。论文中使用ResNet网络在CIFAR-10数据集上的训练误差与测试误差图如下图所示,图中虚线表示训练误差,实线表示测试误差。由图中数据可以看出,ResNet网络层数越深,其训练误差和测试误差越小。\n", + "\n", + "![resnet-4](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/resnet_4.png)\n", + "\n", + "> 了解ResNet网络更多详细内容,参见[ResNet论文](https://arxiv.org/pdf/1512.03385.pdf)。" + ] + }, + { + "cell_type": "markdown", + "id": "a987ee48", + "metadata": {}, + "source": [ + "## 数据集准备与加载\n", + "\n", + "[CIFAR-10数据集](http://www.cs.toronto.edu/~kriz/cifar.html)共有60000张32*32的彩色图像,分为10个类别,每类有6000张图,数据集一共有50000张训练图片和10000张评估图片。首先,如下示例使用`download`接口下载并解压,目前仅支持解析二进制版本的CIFAR-10文件(CIFAR-10 binary version)。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f9b81fb", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating data folder...\n", + "Downloading data 
from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz (162.2 MB)\n", + "\n", + "file_sizes: 100%|████████████████████████████| 170M/170M [00:17<00:00, 9.76MB/s]\n", + "Extracting tar.gz file...\n", + "Successfully downloaded / unzipped to ./datasets-cifar10-bin\n" + ] + }, + { + "data": { + "text/plain": [ + "'./datasets-cifar10-bin'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from download import download\n", + "\n", + "url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/cifar-10-binary.tar.gz\"\n", + "\n", + "download(url, \"./datasets-cifar10-bin\", kind=\"tar.gz\", replace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "7e9020ba", + "metadata": {}, + "source": [ + "下载后的数据集目录结构如下:\n", + "\n", + "```text\n", + "datasets-cifar10-bin/cifar-10-batches-bin\n", + "├── batches.meta.txt\n", + "├── data_batch_1.bin\n", + "├── data_batch_2.bin\n", + "├── data_batch_3.bin\n", + "├── data_batch_4.bin\n", + "├── data_batch_5.bin\n", + "├── readme.html\n", + "└── test_batch.bin\n", + "\n", + "```\n", + "\n", + "然后,使用`mindspore.dataset.Cifar10Dataset`接口来加载数据集,并进行相关图像增强操作。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "df7fb621", + "metadata": {}, + "outputs": [], + "source": [ + "import mindspore as ms\n", + "import mindspore.dataset as ds\n", + "import mindspore.dataset.vision as vision\n", + "import mindspore.dataset.transforms as transforms\n", + "from mindspore import dtype as mstype\n", + "\n", + "data_dir = \"./datasets-cifar10-bin/cifar-10-batches-bin\" # 数据集根目录\n", + "batch_size = 256 # 批量大小\n", + "image_size = 32 # 训练图像空间大小\n", + "workers = 4 # 并行线程个数\n", + "num_classes = 10 # 分类数量\n", + "\n", + "\n", + "def create_dataset_cifar10(dataset_dir, usage, resize, batch_size, workers):\n", + "\n", + " data_set = ds.Cifar10Dataset(dataset_dir=dataset_dir,\n", + " usage=usage,\n", + "
num_parallel_workers=workers,\n", + " shuffle=True)\n", + "\n", + " trans = []\n", + " if usage == \"train\":\n", + " trans += [\n", + " vision.RandomCrop((32, 32), (4, 4, 4, 4)),\n", + " vision.RandomHorizontalFlip(prob=0.5)\n", + " ]\n", + "\n", + " trans += [\n", + " vision.Resize(resize),\n", + " vision.Rescale(1.0 / 255.0, 0.0),\n", + " vision.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),\n", + " vision.HWC2CHW()\n", + " ]\n", + "\n", + " target_trans = transforms.TypeCast(mstype.int32)\n", + "\n", + " # 数据映射操作\n", + " data_set = data_set.map(operations=trans,\n", + " input_columns='image',\n", + " num_parallel_workers=workers)\n", + "\n", + " data_set = data_set.map(operations=target_trans,\n", + " input_columns='label',\n", + " num_parallel_workers=workers)\n", + "\n", + " # 批量操作\n", + " data_set = data_set.batch(batch_size)\n", + "\n", + " return data_set\n", + "\n", + "\n", + "# 获取处理后的训练与测试数据集\n", + "\n", + "dataset_train = create_dataset_cifar10(dataset_dir=data_dir,\n", + " usage=\"train\",\n", + " resize=image_size,\n", + " batch_size=batch_size,\n", + " workers=workers)\n", + "step_size_train = dataset_train.get_dataset_size()\n", + "\n", + "dataset_val = create_dataset_cifar10(dataset_dir=data_dir,\n", + " usage=\"test\",\n", + " resize=image_size,\n", + " batch_size=batch_size,\n", + " workers=workers)\n", + "step_size_val = dataset_val.get_dataset_size()" + ] + }, + { + "cell_type": "markdown", + "id": "21e86f95", + "metadata": {}, + "source": [ + "对CIFAR-10训练数据集进行可视化。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c3ffabb3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Image shape: (256, 3, 32, 32), Label shape: (256,)\n", + "Labels: [1 3 5 3 8 6]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "\n", + "data_iter = next(dataset_train.create_dict_iterator())\n", + "\n", + "images = data_iter[\"image\"].asnumpy()\n", + "labels = data_iter[\"label\"].asnumpy()\n", + "print(f\"Image shape: {images.shape}, Label shape: {labels.shape}\")\n", + "\n", + "# 训练数据集中,前六张图片所对应的标签\n", + "print(f\"Labels: {labels[:6]}\")\n", + "\n", + "classes = []\n", + "\n", + "with open(data_dir + \"/batches.meta.txt\", \"r\") as f:\n", + " for line in f:\n", + " line = line.rstrip()\n", + " if line:\n", + " classes.append(line)\n", + "\n", + "# 训练数据集的前六张图片\n", + "plt.figure()\n", + "for i in range(6):\n", + " plt.subplot(2, 3, i + 1)\n", + " image_trans = np.transpose(images[i], (1, 2, 0))\n", + " mean = np.array([0.4914, 0.4822, 0.4465])\n", + " std = np.array([0.2023, 0.1994, 0.2010])\n", + " image_trans = std * image_trans + mean\n", + " image_trans = np.clip(image_trans, 0, 1)\n", + " plt.title(f\"{classes[labels[i]]}\")\n", + " plt.imshow(image_trans)\n", + " plt.axis(\"off\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "76c96f76", + "metadata": {}, + "source": [ + "## 构建网络\n", + "\n", + "残差网络结构(Residual Network)是ResNet网络的主要亮点,ResNet使用残差网络结构后可有效地减轻退化问题,实现更深的网络结构设计,提高网络的训练精度。本节首先讲述如何构建残差网络结构,然后通过堆叠残差网络来构建ResNet50网络。\n", + "\n", + "### 构建残差网络结构\n", + "\n", + "残差网络结构图如下图所示,残差网络由两个分支构成:一个主分支,一个shortcuts(图中弧线表示)。主分支通过堆叠一系列的卷积操作得到,shortcuts从输入直接到输出,主分支输出的特征矩阵$F(x)$加上shortcuts输出的特征矩阵$x$得到$F(x)+x$,通过Relu激活函数后即为残差网络最后的输出。\n", + "\n", + "![residual](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/resnet_3.png)\n", + "\n", + "残差网络结构主要由两种,一种是Building Block,适用于较浅的ResNet网络,如ResNet18和ResNet34;另一种是Bottleneck,适用于层数较深的ResNet网络,如ResNet50、ResNet101和ResNet152。\n", + "\n", + "#### Building Block\n", + "\n", + "Building Block结构图如下图所示,主分支有两层卷积网络结构:\n", + "\n", + "+ 
主分支第一层网络以输入channel为64为例,首先通过一个$3\\times3$的卷积层,然后通过Batch Normalization层,最后通过Relu激活函数层,输出channel为64;\n", + "+ 主分支第二层网络的输入channel为64,首先通过一个$3\\times3$的卷积层,然后通过Batch Normalization层,输出channel为64。\n", + "\n", + "最后将主分支输出的特征矩阵与shortcuts输出的特征矩阵相加,通过Relu激活函数即为Building Block最后的输出。\n", + "\n", + "![building-block-5](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/resnet_5.png)\n", + "\n", + "主分支与shortcuts输出的特征矩阵相加时,需要保证主分支与shortcuts输出的特征矩阵shape相同。如果主分支与shortcuts输出的特征矩阵shape不相同,如输出channel是输入channel的一倍时,shortcuts上需要使用数量与输出channel相等,大小为$1\\times1$的卷积核进行卷积操作;若输出的图像较输入图像缩小一倍,则要设置shortcuts中卷积操作中的`stride`为2,主分支第一层卷积操作的`stride`也需设置为2。\n", + "\n", + "如下代码定义`ResidualBlockBase`类实现Building Block结构。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c7ac0e2d", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Type, Union, List, Optional\n", + "import mindspore.nn as nn\n", + "from mindspore.common.initializer import Normal\n", + "\n", + "# 初始化卷积层与BatchNorm的参数\n", + "weight_init = Normal(mean=0, sigma=0.02)\n", + "gamma_init = Normal(mean=1, sigma=0.02)\n", + "\n", + "class ResidualBlockBase(nn.Cell):\n", + " expansion: int = 1 # 最后一个卷积核数量与第一个卷积核数量相等\n", + "\n", + " def __init__(self, in_channel: int, out_channel: int,\n", + " stride: int = 1, norm: Optional[nn.Cell] = None,\n", + " down_sample: Optional[nn.Cell] = None) -> None:\n", + " super(ResidualBlockBase, self).__init__()\n", + " if not norm:\n", + " self.norm = nn.BatchNorm2d(out_channel)\n", + " else:\n", + " self.norm = norm\n", + "\n", + " self.conv1 = nn.Conv2d(in_channel, out_channel,\n", + " kernel_size=3, stride=stride,\n", + " weight_init=weight_init)\n", + " self.conv2 = nn.Conv2d(in_channel, out_channel,\n", + " kernel_size=3, weight_init=weight_init)\n", + " self.relu = nn.ReLU()\n", + " self.down_sample = down_sample\n", + "\n", + " def construct(self, x):\n", + " \"\"\"ResidualBlockBase construct.\"\"\"\n", + " 
identity = x # shortcuts分支\n", + "\n", + " out = self.conv1(x) # 主分支第一层:3*3卷积层\n", + " out = self.norm(out)\n", + " out = self.relu(out)\n", + " out = self.conv2(out) # 主分支第二层:3*3卷积层\n", + " out = self.norm(out)\n", + "\n", + " if self.down_sample is not None:\n", + " identity = self.down_sample(x)\n", + " out += identity # 输出为主分支与shortcuts之和\n", + " out = self.relu(out)\n", + "\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "id": "aaa15d3c", + "metadata": {}, + "source": [ + "#### Bottleneck\n", + "\n", + "Bottleneck结构图如下图所示,在输入相同的情况下Bottleneck结构相对Building Block结构的参数数量更少,更适合层数较深的网络,ResNet50使用的残差结构就是Bottleneck。该结构的主分支有三层卷积结构,分别为$1\\times1$的卷积层、$3\\times3$卷积层和$1\\times1$的卷积层,其中$1\\times1$的卷积层分别起降维和升维的作用。\n", + "\n", + "+ 主分支第一层网络以输入channel为256为例,首先通过数量为64,大小为$1\\times1$的卷积核进行降维,然后通过Batch Normalization层,最后通过Relu激活函数层,其输出channel为64;\n", + "+ 主分支第二层网络通过数量为64,大小为$3\\times3$的卷积核提取特征,然后通过Batch Normalization层,最后通过Relu激活函数层,其输出channel为64;\n", + "+ 主分支第三层通过数量为256,大小$1\\times1$的卷积核进行升维,然后通过Batch Normalization层,其输出channel为256。\n", + "\n", + "最后将主分支输出的特征矩阵与shortcuts输出的特征矩阵相加,通过Relu激活函数即为Bottleneck最后的输出。\n", + "\n", + "![building-block-6](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/resnet_6.png)\n", + "\n", + "主分支与shortcuts输出的特征矩阵相加时,需要保证主分支与shortcuts输出的特征矩阵shape相同。如果主分支与shortcuts输出的特征矩阵shape不相同,如输出channel是输入channel的一倍时,shortcuts上需要使用数量与输出channel相等,大小为$1\\times1$的卷积核进行卷积操作;若输出的图像较输入图像缩小一倍,则要设置shortcuts中卷积操作中的`stride`为2,主分支第二层卷积操作的`stride`也需设置为2。\n", + "\n", + "如下代码定义`ResidualBlock`类实现Bottleneck结构。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0d46f98e", + "metadata": {}, + "outputs": [], + "source": [ + "class ResidualBlock(nn.Cell):\n", + " expansion = 4 # 最后一个卷积核的数量是第一个卷积核数量的4倍\n", + "\n", + " def __init__(self, in_channel: int, out_channel: int,\n", + " stride: int = 1, down_sample: Optional[nn.Cell] = None) -> None:\n", + " super(ResidualBlock, self).__init__()\n", + 
"\n", + " self.conv1 = nn.Conv2d(in_channel, out_channel,\n", + " kernel_size=1, weight_init=weight_init)\n", + " self.norm1 = nn.BatchNorm2d(out_channel)\n", + " self.conv2 = nn.Conv2d(out_channel, out_channel,\n", + " kernel_size=3, stride=stride,\n", + " weight_init=weight_init)\n", + " self.norm2 = nn.BatchNorm2d(out_channel)\n", + " self.conv3 = nn.Conv2d(out_channel, out_channel * self.expansion,\n", + " kernel_size=1, weight_init=weight_init)\n", + " self.norm3 = nn.BatchNorm2d(out_channel * self.expansion)\n", + "\n", + " self.relu = nn.ReLU()\n", + " self.down_sample = down_sample\n", + "\n", + " def construct(self, x):\n", + "\n", + " identity = x # shortscuts分支\n", + "\n", + " out = self.conv1(x) # 主分支第一层:1*1卷积层\n", + " out = self.norm1(out)\n", + " out = self.relu(out)\n", + " out = self.conv2(out) # 主分支第二层:3*3卷积层\n", + " out = self.norm2(out)\n", + " out = self.relu(out)\n", + " out = self.conv3(out) # 主分支第三层:1*1卷积层\n", + " out = self.norm3(out)\n", + "\n", + " if self.down_sample is not None:\n", + " identity = self.down_sample(x)\n", + "\n", + " out += identity # 输出为主分支与shortcuts之和\n", + " out = self.relu(out)\n", + "\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "id": "d1d8dfc9", + "metadata": {}, + "source": [ + "#### 构建ResNet50网络\n", + "\n", + "ResNet网络层结构如下图所示,以输入彩色图像$224\\times224$为例,首先通过数量64,卷积核大小为$7\\times7$,stride为2的卷积层conv1,该层输出图片大小为$112\\times112$,输出channel为64;然后通过一个$3\\times3$的最大下采样池化层,该层输出图片大小为$56\\times56$,输出channel为64;再堆叠4个残差网络块(conv2_x、conv3_x、conv4_x和conv5_x),此时输出图片大小为$7\\times7$,输出channel为2048;最后通过一个平均池化层、全连接层和softmax,得到分类概率。\n", + "\n", + "![resnet-layer](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/tutorials/source_zh_cn/cv/images/resnet_2.png)\n", + "\n", + "对于每个残差网络块,以ResNet50网络中的conv2_x为例,其由3个Bottleneck结构堆叠而成,每个Bottleneck输入的channel为64,输出channel为256。\n", + "\n", + "如下示例定义`make_layer`实现残差块的构建,其参数如下所示:\n", + "\n", + "+ `last_out_channel`:上一个残差网络输出的通道数。\n", + "+ 
`block`:残差网络的类别,分别为`ResidualBlockBase`和`ResidualBlock`。\n", + "+ `channel`:残差网络块1*1卷积层的输出通道数\n", + "+ `block_nums`:残差网络块堆叠的个数。\n", + "+ `stride`:卷积移动的步幅。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3dfa40a1", + "metadata": {}, + "outputs": [], + "source": [ + "def make_layer(last_out_channel, block: Type[Union[ResidualBlockBase, ResidualBlock]],\n", + " channel: int, block_nums: int, stride: int = 1):\n", + " down_sample = None # shortcuts分支\n", + "\n", + " if stride != 1 or last_out_channel != channel * block.expansion:\n", + "\n", + " down_sample = nn.SequentialCell([\n", + " nn.Conv2d(last_out_channel, channel * block.expansion,\n", + " kernel_size=1, stride=stride, weight_init=weight_init),\n", + " nn.BatchNorm2d(channel * block.expansion, gamma_init=gamma_init)\n", + " ])\n", + "\n", + " layers = []\n", + " layers.append(block(last_out_channel, channel, stride=stride, down_sample=down_sample))\n", + "\n", + " in_channel = channel * block.expansion\n", + " # 堆叠残差网络\n", + " for _ in range(1, block_nums):\n", + "\n", + " layers.append(block(in_channel, channel))\n", + "\n", + " return nn.SequentialCell(layers)" + ] + }, + { + "cell_type": "markdown", + "id": "67dae353", + "metadata": {}, + "source": [ + "ResNet50网络共有5个卷积结构,一个平均池化层,一个全连接层,以CIFAR-10数据集为例:\n", + "\n", + "+ **conv1**:输入图片大小为$32\\times32$,输入channel为3。首先经过一个卷积核数量为64,卷积核大小为$7\\times7$,stride为2的卷积层;然后通过一个Batch Normalization层;最后通过ReLu激活函数。该层输出feature map大小为$16\\times16$,输出channel为64。\n", + "+ **conv2_x**:输入feature map大小为$16\\times16$,输入channel为64。首先经过一个卷积核大小为$3\\times3$,stride为2的最大下采样池化操作;然后堆叠3个$[1\\times1,64;3\\times3,64;1\\times1,256]$结构的Bottleneck。该层输出feature map大小为$8\\times8$,输出channel为256。\n", + "+ **conv3_x**:输入feature map大小为$8\\times8$,输入channel为256。该层堆叠4个[1×1,128;3×3,128;1×1,512]结构的Bottleneck。该层输出feature map大小为$4\\times4$,输出channel为512。\n", + "+ **conv4_x**:输入feature map大小为$4\\times4$,输入channel为512。该层堆叠6个[1×1,256;3×3,256;1×1,1024]结构的Bottleneck。该层输出feature 
map大小为$2\\times2$,输出channel为1024。\n", + "+ **conv5_x**:输入feature map大小为$2\\times2$,输入channel为1024。该层堆叠3个[1×1,512;3×3,512;1×1,2048]结构的Bottleneck。该层输出feature map大小为$1\\times1$,输出channel为2048。\n", + "+ **average pool & fc**:输入channel为2048,输出channel为分类的类别数。\n", + "\n", + "如下示例代码实现ResNet50模型的构建,通过调用函数`resnet50`即可构建ResNet50模型,函数`resnet50`参数如下:\n", + "\n", + "+ `num_classes`:分类的类别数,默认类别数为1000。\n", + "+ `pretrained`:下载对应的训练模型,并加载预训练模型中的参数到网络中。" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "1ebef3d0", + "metadata": {}, + "outputs": [], + "source": [ + "from mindspore import load_checkpoint, load_param_into_net\n", + "\n", + "\n", + "class ResNet(nn.Cell):\n", + " def __init__(self, block: Type[Union[ResidualBlockBase, ResidualBlock]],\n", + " layer_nums: List[int], num_classes: int, input_channel: int) -> None:\n", + " super(ResNet, self).__init__()\n", + "\n", + " self.relu = nn.ReLU()\n", + " # 第一个卷积层,输入channel为3(彩色图像),输出channel为64\n", + " self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, weight_init=weight_init)\n", + " self.norm = nn.BatchNorm2d(64)\n", + " # 最大池化层,缩小图片的尺寸\n", + " self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')\n", + " # 各个残差网络结构块定义\n", + " self.layer1 = make_layer(64, block, 64, layer_nums[0])\n", + " self.layer2 = make_layer(64 * block.expansion, block, 128, layer_nums[1], stride=2)\n", + " self.layer3 = make_layer(128 * block.expansion, block, 256, layer_nums[2], stride=2)\n", + " self.layer4 = make_layer(256 * block.expansion, block, 512, layer_nums[3], stride=2)\n", + " # 平均池化层\n", + " self.avg_pool = nn.AvgPool2d()\n", + " # flattern层\n", + " self.flatten = nn.Flatten()\n", + " # 全连接层\n", + " self.fc = nn.Dense(in_channels=input_channel, out_channels=num_classes)\n", + "\n", + " def construct(self, x):\n", + "\n", + " x = self.conv1(x)\n", + " x = self.norm(x)\n", + " x = self.relu(x)\n", + " x = self.max_pool(x)\n", + "\n", + " x = self.layer1(x)\n", + " x = self.layer2(x)\n", + " x =
self.layer3(x)\n", + " x = self.layer4(x)\n", + " print(x.shape)\n", + " x = self.avg_pool(x)\n", + " print(x.shape)\n", + " x = self.flatten(x)\n", + " print(x.shape)\n", + " x = self.fc(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "d16e658e", + "metadata": {}, + "outputs": [], + "source": [ + "def _resnet(model_url: str, block: Type[Union[ResidualBlockBase, ResidualBlock]],\n", + " layers: List[int], num_classes: int, pretrained: bool, pretrained_ckpt: str,\n", + " input_channel: int):\n", + " model = ResNet(block, layers, num_classes, input_channel)\n", + "\n", + " if pretrained:\n", + " # 加载预训练模型\n", + " download(url=model_url, path=pretrained_ckpt, replace=True)\n", + " param_dict = load_checkpoint(pretrained_ckpt)\n", + " load_param_into_net(model, param_dict)\n", + "\n", + " return model\n", + "\n", + "\n", + "def resnet50(num_classes: int = 1000, pretrained: bool = False):\n", + " \"\"\"ResNet50模型\"\"\"\n", + " resnet50_url = \"https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt\"\n", + " resnet50_ckpt = \"./LoadPretrainedModel/resnet50_224_new.ckpt\"\n", + " return _resnet(resnet50_url, ResidualBlock, [3, 4, 6, 3], num_classes,\n", + " pretrained, resnet50_ckpt, 2048)" + ] + }, + { + "cell_type": "markdown", + "id": "d40bd05a", + "metadata": {}, + "source": [ + "## 模型训练与评估\n", + "\n", + "本节使用[ResNet50预训练模型](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/resnet50_224_new.ckpt)进行微调。调用`resnet50`构造ResNet50模型,并设置`pretrained`参数为True,将会自动下载ResNet50预训练模型,并加载预训练模型中的参数到网络中。然后定义优化器和损失函数,逐个epoch打印训练的损失值和评估精度,并保存评估精度最高的ckpt文件(resnet50-best.ckpt)到当前路径的./BestCheckPoint下。\n", + "\n", + "由于预训练模型全连接层(fc)的输出大小(对应参数`num_classes`)为1000, 为了成功加载预训练权重,我们将模型的全连接输出大小设置为默认的1000。CIFAR10数据集共有10个分类,在使用该数据集进行训练时,需要将加载好预训练权重的模型全连接层输出大小重置为10。\n", + "\n", + "> 此处我们展示了5个epochs的训练过程,如果想要达到理想的训练效果,建议训练80个epochs。" + ] + }, + { + "cell_type": 
"code", + "execution_count": 8, + "id": "9cf10c03", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ResNet<\n", + " (relu): ReLU<>\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm): BatchNorm2d\n", + " (max_pool): MaxPool2d\n", + " (layer1): SequentialCell<\n", + " (0): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (down_sample): SequentialCell<\n", + " (0): Conv2d, bias_init=None, format=NCHW>\n", + " (1): BatchNorm2d\n", + " >\n", + " >\n", + " (1): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (2): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " >\n", + " (layer2): SequentialCell<\n", + " (0): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (down_sample): SequentialCell<\n", + " (0): Conv2d, bias_init=None, format=NCHW>\n", + " (1): BatchNorm2d\n", + " >\n", + " >\n", + " (1): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, 
bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (2): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (3): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " >\n", + " (layer3): SequentialCell<\n", + " (0): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (down_sample): SequentialCell<\n", + " (0): Conv2d, bias_init=None, format=NCHW>\n", + " (1): BatchNorm2d\n", + " >\n", + " >\n", + " (1): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (2): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (3): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): 
BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (4): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (5): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " >\n", + " (layer4): SequentialCell<\n", + " (0): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " (down_sample): SequentialCell<\n", + " (0): Conv2d, bias_init=None, format=NCHW>\n", + " (1): BatchNorm2d\n", + " >\n", + " >\n", + " (1): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " (2): ResidualBlock<\n", + " (conv1): Conv2d, bias_init=None, format=NCHW>\n", + " (norm1): BatchNorm2d\n", + " (conv2): Conv2d, bias_init=None, format=NCHW>\n", + " (norm2): BatchNorm2d\n", + " (conv3): Conv2d, bias_init=None, format=NCHW>\n", + " (norm3): BatchNorm2d\n", + " (relu): ReLU<>\n", + " >\n", + " >\n", + " (avg_pool): AvgPool2d\n", + " (flatten): 
Flatten<>\n", + " (fc): Dense\n", + " >\n" + ] + } + ], + "source": [ + "# 定义ResNet50网络\n", + "network = resnet50(pretrained=False)\n", + "\n", + "# 全连接层输入层的大小\n", + "in_channel = network.fc.in_channels\n", + "fc = nn.Dense(in_channels=in_channel, out_channels=10)\n", + "# 重置全连接层\n", + "network.fc = fc\n", + "print(network)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "e1c632ff", + "metadata": {}, + "outputs": [], + "source": [ + "# 设置学习率\n", + "num_epochs = 5\n", + "lr = nn.cosine_decay_lr(min_lr=0.00001, max_lr=0.001, total_step=step_size_train * num_epochs,\n", + " step_per_epoch=step_size_train, decay_epoch=num_epochs)\n", + "# 定义优化器和损失函数\n", + "opt = nn.Momentum(params=network.trainable_params(), learning_rate=lr, momentum=0.9)\n", + "loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')\n", + "\n", + "\n", + "def forward_fn(inputs, targets):\n", + " logits = network(inputs)\n", + " loss = loss_fn(logits, targets)\n", + " return loss\n", + "\n", + "\n", + "grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)\n", + "\n", + "\n", + "def train_step(inputs, targets):\n", + " loss, grads = grad_fn(inputs, targets)\n", + " opt(grads)\n", + " return loss" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b627e30c", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# 创建迭代器\n", + "data_loader_train = dataset_train.create_tuple_iterator(num_epochs=num_epochs)\n", + "data_loader_val = dataset_val.create_tuple_iterator(num_epochs=num_epochs)\n", + "\n", + "# 最佳模型存储路径\n", + "best_acc = 0\n", + "best_ckpt_dir = \"./BestCheckpoint\"\n", + "best_ckpt_path = \"./BestCheckpoint/resnet50-best.ckpt\"\n", + "\n", + "if not os.path.exists(best_ckpt_dir):\n", + " os.mkdir(best_ckpt_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8a5170df", + "metadata": {}, + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + 
"\u001b[1;31m在当前单元格或上一个单元格中执行代码时 Kernel 崩溃。\n", + "\u001b[1;31m请查看单元格中的代码,以确定故障的可能原因。\n", + "\u001b[1;31m单击此处了解详细信息。\n", + "\u001b[1;31m有关更多详细信息,请查看 Jupyter log。" + ] + } + ], + "source": [ + "import mindspore.ops as ops\n", + "\n", + "\n", + "def train(data_loader, epoch):\n", + " \"\"\"模型训练\"\"\"\n", + " losses = []\n", + " network.set_train(True)\n", + "\n", + " for i, (images, labels) in enumerate(data_loader):\n", + " loss = train_step(images, labels)\n", + " if i % 100 == 0 or i == step_size_train - 1:\n", + " print('Epoch: [%3d/%3d], Steps: [%3d/%3d], Train Loss: [%5.3f]' %\n", + " (epoch + 1, num_epochs, i + 1, step_size_train, loss))\n", + " losses.append(loss)\n", + "\n", + " return sum(losses) / len(losses)\n", + "\n", + "\n", + "def evaluate(data_loader):\n", + " \"\"\"模型验证\"\"\"\n", + " network.set_train(False)\n", + "\n", + " correct_num = 0.0 # 预测正确个数\n", + " total_num = 0.0 # 预测总数\n", + "\n", + " for images, labels in data_loader:\n", + " logits = network(images)\n", + " pred = logits.argmax(axis=1) # 预测结果\n", + " correct = ops.equal(pred, labels).reshape((-1, ))\n", + " correct_num += correct.sum().asnumpy()\n", + " total_num += correct.shape[0]\n", + "\n", + " acc = correct_num / total_num # 准确率\n", + "\n", + " return acc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "562a04ca", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Start Training Loop ...\n", + "(256, 2048, 1, 1)\n", + "(256, 2048, 1, 1)\n", + "(256, 2048)\n", + "Epoch: [ 1/ 5], Steps: [ 1/196], Train Loss: [2.524]\n", + "(256, 2048, 1, 1)\n", + "(256, 2048, 1, 1)\n", + "(256, 2048)\n", + "(256, 2048, 1, 1)\n", + "(256, 2048, 1, 1)\n", + "(256, 2048)\n", + "(256, 2048, 1, 1)\n", + "(256, 2048, 1, 1)\n", + "(256, 2048)\n" + ] + } + ], + "source": [ + "# 开始循环训练\n", + "print(\"Start Training Loop ...\")\n", + "\n", + "for epoch in range(num_epochs):\n", + " curr_loss = train(data_loader_train, epoch)\n", + " 
curr_acc = evaluate(data_loader_val)\n", + "\n", + " print(\"-\" * 50)\n", + " print(\"Epoch: [%3d/%3d], Average Train Loss: [%5.3f], Accuracy: [%5.3f]\" % (\n", + " epoch+1, num_epochs, curr_loss, curr_acc\n", + " ))\n", + " print(\"-\" * 50)\n", + "\n", + " # 保存当前预测准确率最高的模型\n", + " if curr_acc > best_acc:\n", + " best_acc = curr_acc\n", + " ms.save_checkpoint(network, best_ckpt_path)\n", + "\n", + "print(\"=\" * 80)\n", + "print(f\"End of validation the best Accuracy is: {best_acc: 5.3f}, \"\n", + " f\"save the best ckpt file in {best_ckpt_path}\", flush=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "46e28f6f", + "metadata": {}, + "source": [ + "## 可视化模型预测\n", + "\n", + "定义`visualize_model`函数,使用上述验证精度最高的模型对CIFAR-10测试数据集进行预测,并将预测结果可视化。若预测字体颜色为蓝色表示为预测正确,预测字体颜色为红色则表示预测错误。\n", + "\n", + "> 由上面的结果可知,5个epochs下模型在验证数据集的预测准确率在70%左右,即一般情况下,6张图片中会有2张预测失败。如果想要达到理想的训练效果,建议训练80个epochs。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6ba2fa94", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def visualize_model(best_ckpt_path, dataset_val):\n", + " num_class = 10\n", + " net = resnet50(num_class)\n", + " # 加载模型参数\n", + " param_dict = ms.load_checkpoint(best_ckpt_path)\n", + " ms.load_param_into_net(net, param_dict)\n", + " # 加载验证集的数据进行验证\n", + " data = next(dataset_val.create_dict_iterator())\n", + " images = data[\"image\"]\n", + " labels = data[\"label\"]\n", + " # 预测图像类别\n", + " output = net(data['image'])\n", + " pred = np.argmax(output.asnumpy(), axis=1)\n", + "\n", + " # 图像分类\n", + " classes = []\n", + "\n", + " with open(data_dir + \"/batches.meta.txt\", \"r\") as f:\n", + " for line in f:\n", + " line = line.rstrip()\n", + " if line:\n", + " classes.append(line)\n", + "\n", + " # 显示图像及图像的预测值\n", + " plt.figure()\n", + " for i in range(6):\n", + " plt.subplot(2, 3, i + 1)\n", + " # 若预测正确,显示为蓝色;若预测错误,显示为红色\n", + " color = 'blue' if pred[i] == labels.asnumpy()[i] else 'red'\n", + " plt.title('predict:{}'.format(classes[pred[i]]), color=color)\n", + " picture_show = np.transpose(images.asnumpy()[i], (1, 2, 0))\n", + " mean = np.array([0.4914, 0.4822, 0.4465])\n", + " std = np.array([0.2023, 0.1994, 0.2010])\n", + " picture_show = std * picture_show + mean\n", + " picture_show = np.clip(picture_show, 0, 1)\n", + " plt.imshow(picture_show)\n", + " plt.axis('off')\n", + "\n", + " plt.show()\n", + "\n", + "\n", + "# 使用测试数据集进行验证\n", + "visualize_model(best_ckpt_path=best_ckpt_path, dataset_val=dataset_val)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + 
"nbformat_minor": 5 +}