2D Wavelet Transforms in Pytorch


The full documentation is also available here.

This package provides support for computing the 2D discrete wavelet transform (DWT) and the 2D dual-tree complex wavelet transform (DTCWT), their inverses, and passing gradients through both using PyTorch.

The implementation is designed to be used with batches of multichannel images. We use PyTorch's standard 'NCHW' data format (batch, channels, height, width).
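Images loaded by many libraries are stored channels-last (NHWC), so their axes need reordering before they reach the transforms. The minimal NumPy sketch below uses a random array as a stand-in for real images (the array and its sizes are illustrative only); for a PyTorch tensor, `x.permute(0, 3, 1, 2)` performs the same reordering.

```python
import numpy as np

# Stand-in for a batch of 8 RGB images stored channels-last (N, H, W, C).
batch_nhwc = np.random.rand(8, 64, 48, 3).astype(np.float32)

# Move the channel axis next to the batch axis: (N, H, W, C) -> (N, C, H, W).
# ascontiguousarray gives a C-contiguous copy, which torch.from_numpy accepts.
batch_nchw = np.ascontiguousarray(batch_nhwc.transpose(0, 3, 1, 2))

print(batch_nchw.shape)  # -> (8, 3, 64, 48)
```

The result can then be wrapped with `torch.from_numpy(batch_nchw)` and passed to any of the transforms.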

We have also added layers to compute a 2-D DTCWT-based scatternet. This is similar to the Morlet-based scatternet in Kymatio, but roughly 10 times faster.

For citing, please use the DOI for the moment. We may release a paper in due time describing the repo.

New in version 1.2.0

  • Added a DTCWT based ScatterNet
import torch
from pytorch_wavelets import ScatLayer
scat = ScatLayer()
X = torch.randn(10,5,64,64)
# A first order scatternet with 6 orientations and one lowpass channel
# gives 7 times the input channel dimension (and halves the spatial size)
Z = scat(X)
print(Z.shape)
# -> torch.Size([10, 35, 32, 32])
# A second order scatternet with 6 orientations and one lowpass channel
# gives 7^2 times the input channel dimension
scat2 = torch.nn.Sequential(ScatLayer(), ScatLayer())
Z = scat2(X)
print(Z.shape)
# -> torch.Size([10, 245, 16, 16])
# We also have a slightly more specialized, but slower, second order scatternet
from pytorch_wavelets import ScatLayerj2
scat2a = ScatLayerj2()
Z = scat2a(X)
print(Z.shape)
# -> torch.Size([10, 245, 16, 16])
# All of these layers also work on the GPU
if torch.cuda.is_available():
    scat2a.cuda()
    Z = scat2a(X.cuda())
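When a scatternet feeds a later layer such as a linear classifier, it helps to know its output size in advance. The shapes printed above suggest a simple rule: each ScatLayer multiplies the channel count by 7 and halves the height and width. The helper below is a hypothetical sketch of that rule, not part of pytorch_wavelets, and it assumes the spatial sizes stay divisible by 2 at every layer.

```python
def scat_output_shape(input_shape, order):
    """Hypothetical helper: expected output shape of `order` stacked ScatLayers.

    Each layer keeps 6 oriented bandpass outputs plus 1 lowpass output per
    input channel (a factor of 7) and halves both spatial dimensions.
    Assumes height and width are divisible by 2**order.
    """
    n, c, h, w = input_shape
    return (n, c * 7 ** order, h // 2 ** order, w // 2 ** order)


print(scat_output_shape((10, 5, 64, 64), 1))  # -> (10, 35, 32, 32)
print(scat_output_shape((10, 5, 64, 64), 2))  # -> (10, 245, 16, 16)
```

For example, the flattened feature size of a second order scatternet on this input would be 245 * 16 * 16.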

New in version 1.1.0

  • Fixed a memory problem with the DWT
  • Rewrote the backend code for the DTCWT calculation: much cleaner now, with similar performance
  • Both the DTCWT and the DWT should now be more memory efficient
  • Removed the need to specify the number of scales for DTCWTInverse

New in version 1.0.0

Version 1.0.0 adds support for separable DWT calculation and more padding schemes, such as symmetric, zero and periodization.

Also, you no longer need to specify the number of channels when creating the wavelet transform classes.
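The padding mode also determines how large the coefficients are at each level. The sketch below assumes pytorch_wavelets follows the PyWavelets output-length conventions: floor((n + L - 1) / 2) for 'zero' and 'symmetric', and ceil(n / 2) for 'periodization', where L is the filter length. The zero-mode result agrees with the shapes in the DWT example under Example Use. Both function names are hypothetical.

```python
import math


def dwt_output_length(n, filter_len, mode):
    """Hypothetical helper: length of one DWT level along a single axis.

    Assumes PyWavelets-style conventions for the supported modes.
    """
    if mode == "periodization":
        return math.ceil(n / 2)
    if mode in ("zero", "symmetric"):
        return (n + filter_len - 1) // 2
    raise ValueError(f"unsupported mode: {mode!r}")


def dwt_level_sizes(n, filter_len, mode, levels):
    """Hypothetical helper: size along one axis after each successive level."""
    sizes = []
    for _ in range(levels):
        n = dwt_output_length(n, filter_len, mode)
        sizes.append(n)
    return sizes


# db3 has 6 taps; decompose a 64-pixel side over 3 levels.
print(dwt_level_sizes(64, 6, "zero", 3))           # -> [34, 19, 12]
print(dwt_level_sizes(64, 6, "periodization", 3))  # -> [32, 16, 8]
```

Periodization is the only one of these modes that produces exactly half-sized outputs, which can be convenient when the coefficients must line up with other downsampled feature maps.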

Speed Tests

We compare computing the DTCWT with the dtcwt Python package and the DWT with PyWavelets against computing both with pytorch_wavelets on a GTX1080. The NumPy-based methods were run on a 14-core Xeon Phi machine using Intel's parallel Python. For the DTCWT we use the near_sym_a filters for the first scale and the qshift_a filters for subsequent scales. For the DWT we use the db4 filters.

For a fixed input size but a varying number of scales (from 1 to 4), we have the following speeds (averaged over 5 runs):

For a 512 by 512 input, we also vary the batch size for a 3-scale transform. The resulting speeds were:

Installation

The easiest way to install pytorch_wavelets is to clone the repo and pip install it (later versions will be released on PyPI once the docs have been updated):

$ git clone https://github.com/fbcotter/pytorch_wavelets
$ cd pytorch_wavelets
$ pip install .

If you intend to make significant modifications to the library, the develop command (an editable install) may be more useful. A test suite is provided so that you can verify that the code works on your system:

$ pip install -r tests/requirements.txt
$ pytest tests/

Example Use

For the DWT, note that the highpass output has an extra dimension, in which we stack the (lh, hl, hh) coefficients. Also note that the Yh output lists the finest detail coefficients first and the coarsest last (the opposite of PyWavelets).

import torch
from pytorch_wavelets import DWTForward, DWTInverse
xfm = DWTForward(J=3, wave='db3', mode='zero')
X = torch.randn(10,5,64,64)
Yl, Yh = xfm(X)
print(Yl.shape)
# -> torch.Size([10, 5, 12, 12])
print(Yh[0].shape)
# -> torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)
# -> torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)
# -> torch.Size([10, 5, 3, 12, 12])
ifm = DWTInverse(wave='db3', mode='zero')
Y = ifm((Yl, Yh))
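Because Yh is ordered finest-first, code written against the PyWavelets `wavedec2` layout (approximation first, then detail tuples from coarsest to finest) needs the list reversed and the stacked subband axis split. This is a minimal sketch on NumPy arrays with the shapes printed above; `to_pywt_layout` is a hypothetical helper, and the assumption that (lh, hl, hh) line up with PyWavelets' (cH, cV, cD) should be checked against your own data. The same indexing also works on PyTorch tensors.

```python
import numpy as np


def to_pywt_layout(yl, yh):
    """Hypothetical helper: reorder DWT output into a wavedec2-style list.

    yl: lowpass array of shape (N, C, H, W).
    yh: list of highpass arrays, finest scale first, each of shape
        (N, C, 3, H_j, W_j) with the three subbands stacked on axis 2.
    Returns [yl, (b0, b1, b2) coarsest, ..., (b0, b1, b2) finest].
    """
    coeffs = [yl]
    for band_stack in reversed(yh):  # coarsest scale first, as in PyWavelets
        # Assumption: the stacked (lh, hl, hh) order matches (cH, cV, cD).
        coeffs.append(tuple(band_stack[:, :, k] for k in range(3)))
    return coeffs


# Dummy arrays with the shapes from the DWT example (J=3, db3, zero mode).
yl = np.zeros((10, 5, 12, 12))
yh = [np.zeros((10, 5, 3, s, s)) for s in (34, 19, 12)]
coeffs = to_pywt_layout(yl, yh)
print([c.shape if isinstance(c, np.ndarray) else c[0].shape for c in coeffs])
# -> [(10, 5, 12, 12), (10, 5, 12, 12), (10, 5, 19, 19), (10, 5, 34, 34)]
```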

For the DTCWT:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b')
X = torch.randn(10,5,64,64)
Yl, Yh = xfm(X)
print(Yl.shape)
# -> torch.Size([10, 5, 16, 16])
print(Yh[0].shape)
# -> torch.Size([10, 5, 6, 32, 32, 2])
print(Yh[1].shape)
# -> torch.Size([10, 5, 6, 16, 16, 2])
print(Yh[2].shape)
# -> torch.Size([10, 5, 6, 8, 8, 2])
# Since version 1.1.0 the inverse no longer needs the number of scales
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b')
Y = ifm((Yl, Yh))

Some initial notes:

  • Yh returned is a tuple. There are 2 extra dimensions: the first comes between the channel dimension of the input and the row dimension, and holds the 6 orientations of the DTCWT. The second is the final dimension, which holds the real and imaginary parts (complex tensors were not native to PyTorch when this library was written).
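Many downstream uses, such as scatternets, need the magnitude of each complex coefficient, so it can be convenient to fold the trailing real/imaginary axis back into a complex array. A minimal NumPy sketch, using random data with the shape of Yh[2] above; `to_complex` is a hypothetical helper. In recent PyTorch versions, `torch.view_as_complex` does the same for a contiguous float tensor whose last dimension has size 2.

```python
import numpy as np


def to_complex(band):
    """Hypothetical helper: merge a trailing (real, imag) axis into complex values.

    band: array of shape (N, C, 6, H, W, 2).
    Returns a complex array of shape (N, C, 6, H, W).
    """
    if band.shape[-1] != 2:
        raise ValueError("expected a trailing axis of size 2 (real, imag)")
    return band[..., 0] + 1j * band[..., 1]


# Random stand-in for Yh[2] from the DTCWT example.
yh2 = np.random.randn(10, 5, 6, 8, 8, 2)
z = to_complex(yh2)
magnitude = np.abs(z)  # per-orientation magnitude, shape (10, 5, 6, 8, 8)
print(z.shape, z.dtype)  # -> (10, 5, 6, 8, 8) complex128
```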

Running on the GPU

This should come as no surprise to PyTorch users: the DWT and DTCWT transforms run on the GPU once the modules and their inputs are moved to CUDA:

import torch
from pytorch_wavelets import DTCWTForward, DTCWTInverse
xfm = DTCWTForward(J=3, biort='near_sym_b', qshift='qshift_b').cuda()
X = torch.randn(10,5,64,64).cuda()
Yl, Yh = xfm(X)
ifm = DTCWTInverse(biort='near_sym_b', qshift='qshift_b').cuda()
Y = ifm((Yl, Yh))

The automated tests cannot exercise the GPU functionality, but they do check the CPU code paths. To test whether the repo works on your GPU, clone the repo, make sure you have PyTorch with CUDA enabled (the tests check whether torch.cuda.is_available() returns True), and run the following from the base of the repo:

$ pip install -r tests/requirements.txt
$ pytest tests/

Backpropagation

It is possible to pass gradients through both the forward and the inverse transforms. All you need to do is ensure that the input to each has its requires_grad attribute set to True.

Provenance

Based on the Dual-Tree Complex Wavelet Transform Pack for MATLAB by Nick Kingsbury, Cambridge University. The original README can be found in ORIGINAL_README.txt. This file outlines the conditions of use of the original MATLAB toolbox.

Further information on the DT CWT can be obtained from papers downloadable from my website (given below). The best tutorial is in the 1999 Royal Society Paper. In particular this explains the conversion between 'real' quad-number subimages and pairs of complex subimages. The Q-shift filters are explained in the ICIP 2000 paper and in more detail in the May 2001 paper for the Journal on Applied and Computational Harmonic Analysis.

This code is copyright and is supplied free of charge for research purposes only. In return for supplying the code, all I ask is that, if you use the algorithms, you give due reference to this work in any papers that you write and that you let me know if you find any good applications for the DT CWT. If the applications are good, I would be very interested in collaboration. I accept no liability arising from use of these algorithms.

Nick Kingsbury, Cambridge University, June 2003.

Dr N G Kingsbury, Dept. of Engineering, University of Cambridge, Trumpington St., Cambridge CB2 1PZ, UK., or Trinity College, Cambridge CB2 1TQ, UK. Phone: (0 or +44) 1223 338514 / 332647; Home: 1954 211152; Fax: 1223 338564 / 332662; E-mail: ngk@eng.cam.ac.uk Web home page: http://www.eng.cam.ac.uk/~ngk/

This licence applies to any parts of this library which are novel in comparison to the original DTCWT MATLAB toolbox written by Nick Kingsbury and Cian Shaffrey. See the Provenance section of the README.rst file for details on any further restrictions of use. If you wish to use the DTCWT, you should read that license as well. The DWT sections come under this license.

MIT License

Copyright (c) 2020 Fergal Cotter

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
