怎樣克服神經網路訓練中argmax的不可導性?

時間 2021-05-06 18:01:31

1樓:章浩

一句話解釋: 正向傳播就和往常一樣,反向傳播時,將梯度從不可導那個點copy到不可導點的前面的最近乙個可導點。

(請看紅線右端點的梯度,跳過中間的字典模組,直達紅線的左端點)

問題來了

1/梯度鏈條怎麼隔斷不讓他經過字典模組?pytorch有個 detach(), 可以隔斷梯度,梯度就不會進入不可導區域引發編譯器報錯

2/梯度怎麼複製?舉個最簡單例子

quantize

=input+(

quantize

-input).

detach

()# 正向傳播和往常一樣,

# 反向傳播時,detach()這部分梯度為0,quantize和input的梯度相同,

# 即實現將quantize複製給input

# quantize即紅線右端點,input即紅線左端點

參考:[1]. Neural Discrete Representation Learning

2樓:

在三維重建任務中,經常會用到畫素點的深度估計,其主要方式就是找最小代價值對應的深度索引。

通過可微的方式估計:

其中,c_d 表示的深度為d時候對應的代價。近似估計的結果為:

如果,要求 argmax 則將代價值的負號去掉,實現近似估計。

Kendall, A., Martirosyan, H., Dasgupta, S.

, Henry, P., Kennedy, R., Bachrach, A.

, & Bry, A. (2017). End-to-end learning of geometry and context for deep stereo regression.

InProceedings of the IEEE International Conference on Computer Vision(pp. 66-75).

3樓:

其實就是乙個在離散分布中取樣的問題具體可以參考 gumble softmax 這裡貼一篇蘇神的文章漫談重引數:從正態分佈到Gumbel Softmax

4樓:king熊

argmax

我的做法是這樣的,利用帶有溫度係數的softmax模擬,一般只要k值設定的精準時,softmax的結果就是one-hot,這個時候one-hot乘法,就實現了argmax。

argmax-k

這塊做法是一樣的,利用多次溫度係數的softmax模擬。但存在乙個問題,如果top-2的值相等時,出來的softmax是兩個0.5,這個時候使用gumbel-softmax對softamx的結果進行取樣,得到的結果代替softmax結果,這樣就實現了。。

這個問題,我解決了半個月,太費勁了。有啥問題再交流。

5樓:OwlLite

可以對argmax/argmin 這種不可導的操作直接忽視,也就是鎖定:

class

ArgMax

(torch

.autograd

.Function

):@staticmethod

defforward

(ctx

,input

):idx

=torch

.argmax

(input,1

)output

=torch

.zeros_like

(input

)output

.scatter_(1

,idx,1

)return

output

@staticmethod

defbackward

(ctx

,grad_output

):return

grad_output

6樓:

VAE、RL、GAN都是有效的涉及離散取樣方法

前者可以用上面說的Gumbel Softmax Trick,具體怎麼弄可以自己去搜下

7樓:Kanata AKIZUKI

自己來進行下回答,在 AAAI-2017 上的 SeqGAN 一文中出現過相似問題的解決辦法。

這裡由於 seq2seq 的特點,在作為generator的生成器時會導致如下問題:

1、generator 的輸出為概率值,這裡根據概率的取樣是離散的,因而無法在輸入 discriminator 保證loss的反向傳播。

2、generator 在每次生成的單詞中無法對後續序列評價該單詞所產生的影響。

對真實文字進行 one-hot 轉換是不可行的,因為這樣會導致 discriminator 只通過最後一維的資料是否由 0 和 1 構成來進行判定,極易導致訓練崩潰。

目前比較好的解決辦法是使用 policy gradient 來訓練 generator,使用蒙特卡洛取樣用已知概率分布模擬整體產生完整句子,並輸入 discriminator 進行 reward 計算。

另外關於 Gumble softmax, 這類再引數化 trick 雖然將取樣過程移到 bp 步驟外面,但是只是對one-hot的光滑逼近。與此同時我們為了保證輸入維度的一致性,需要將真實文字轉變成 one-hot 的輸入,而對於判決器而言,學習這種階梯性的函式是十分容易的。

換句話說,判決器只需要根據最後一維輸入是否為 0, 1 構成即可,這會導致極大地訓練的不均衡性。

8樓:yyll

可以參考

JimmySuen/integral-human-pose

他們自己實現了乙個可導的argmax(根據heatmap得到響應最高點的座標作為人體姿態點的輸出座標),具體如何實現我沒了解。

9樓:V1xerunt

單就解決argmax的不可導性這一問題而言,可以使用argmax的近似估計:

其中h是經過softmax後的向量,N是h的維度,h(i)是h中第i個值

10樓:

用cross entropy loss

loss=nn.CrossEntropyLoss()

out=torch.randn(batch_size,max_length,vocab_length, requires_grad=True)

target=torch.empty(batch_size,max_length,dtype=torch.long).

random_(vocab_length)

output=loss(out, target)output.backward()

神經網路中訓練和推理有什麼區別?

tohnee 1.推斷 Inference 的網路權值已經固定下來,無後向傳播過程,因此可以 模型固定,可以對計算圖進行優化,還可以輸入輸出大小固定,可以做memory優化 注意 有乙個概念是fine tuning,即訓練好的模型繼續調優,只是在已有的模型做小的改動,本質上仍然是訓練 Trainin...

怎麼選取訓練神經網路時的Batch size

嚮往自由 乙個epoch,使用大batch,訓練時間更短 但收斂不一定比小batch好 解析為什麼同乙個epoch,小batch,收斂更快?原因 相同epoch,小batch的梯度迭代更加頻繁,更有可能找到最優解。因此,不是batch越大越好 做自己 我覺得和隨機性大小有關,批梯度下降本身屬於隨機優...

神經網路中,設計loss function有哪些技巧

DLing 損失函式是神經網路能正常訓練的基礎,所以在日常研究中,損失函式也是大家攻關的熱門,可以供大家學習的太多了,各種技巧也很多,FocalLoss,OHEM 每乙個都很有代表性。其實這麼多損失函式中,我自己感覺還是CenterNet的損失函式 或者是Corner Net 最能讓我眼前一亮,整個...