温馨提示×

温馨提示×

您好,登录后才能下订单哦!

密码登录×
登录注册×
其他方式登录
点击 登录注册 即表示同意《亿速云用户服务条款》

PyTorch梯度下降反向传播实例分析

发布时间:2022-03-09 13:33:43 阅读:214 作者:iii 栏目:开发技术
开发者测试专用服务器限时活动,0元免费领,库存有限,领完即止! 点击查看>>

本文小编为大家详细介绍“PyTorch梯度下降反向传播实例分析”,内容详细,步骤清晰,细节处理妥当,希望这篇“PyTorch梯度下降反向传播实例分析”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

前言:

反向传播的目的是计算成本函数C对网络中任意w或b的偏导数。一旦我们有了这些偏导数,我们将通过一些常数 α的乘积和该数量相对于成本函数的偏导数来更新网络中的权重和偏差。这是流行的梯度下降算法。而偏导数给出了最大上升的方向。因此,关于反向传播算法,我们继续查看下文。

我们向相反的方向迈出了一小步——最大下降的方向,也就是将我们带到成本函数的局部最小值的方向

如题:

PyTorch梯度下降反向传播实例分析

意思是利用这个二次模型来预测数据,减小损失函数(MSE)的值。

代码如下:

import torch
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"]  =  "TRUE"
# 数据集
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0]
# 权重参数初始值均为1
w = torch.tensor([1.0,1.0,1.0])
w.requires_grad = True    # 需要计算梯度

# 前向传播
def forward(x):
    return w[0]*(x**2)+w[1]*x+w[2]
# 计算损失
def loss(x,y):
    y_pred = forward(x)
    return (y_pred-y) ** 2

# 训练模块
print('predict (before tranining) ',4, forward(4).item())
epoch_list = []
w_list = []
loss_list = []
for epoch in range(1000):
    for x,y in zip(x_data,y_data):
        l = loss(x,y)
        l.backward()        # 后向传播
        print('\tgrad: ',x,y,w.grad.data)
        w.data = w.data - 0.01 * w.grad.data        # 梯度下降
        
        w.grad.data.zero_()    # 梯度清零操作
        
    print('progress: ',epoch,l.item())
    epoch_list.append(epoch)
    w_list.append(w.data)
    loss_list.append(l.item())
print('predict (after tranining) ',4, forward(4).item())

# 绘图
plt.plot(epoch_list,loss_list,'b')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid()
plt.show()

结果如下:

predict (before tranining)  4 21.0
    grad:  1.0 2.0 tensor([2., 2., 2.])
    grad:  2.0 4.0 tensor([22.880011.4400,  5.7200])
    grad:  3.0 6.0 tensor([77.047225.6824,  8.5608])
progress:  0 18.321826934814453
    grad:  1.0 2.0 tensor([-1.1466, -1.1466, -1.1466])
    grad:  2.0 4.0 tensor([-15.5367,  -7.7683,  -3.8842])
    grad:  3.0 6.0 tensor([-30.4322, -10.1441,  -3.3814])
progress:  1 2.858394145965576
    grad:  1.0 2.0 tensor([0.34510.34510.3451])
    grad:  2.0 4.0 tensor([2.42731.21370.6068])
    grad:  3.0 6.0 tensor([19.4499,  6.4833,  2.1611])
progress:  2 1.1675907373428345
    grad:  1.0 2.0 tensor([-0.3224, -0.3224, -0.3224])
    grad:  2.0 4.0 tensor([-5.8458, -2.9229, -1.4614])
    grad:  3.0 6.0 tensor([-3.8829, -1.2943, -0.4314])
progress:  3 0.04653334245085716
    grad:  1.0 2.0 tensor([0.01370.01370.0137])
    grad:  2.0 4.0 tensor([-1.9141, -0.9570, -0.4785])
    grad:  3.0 6.0 tensor([6.85572.28520.7617])
progress:  4 0.14506366848945618
    grad:  1.0 2.0 tensor([-0.1182, -0.1182, -0.1182])
    grad:  2.0 4.0 tensor([-3.6644, -1.8322, -0.9161])
    grad:  3.0 6.0 tensor([1.74550.58180.1939])
progress:  5 0.009403289295732975
    grad:  1.0 2.0 tensor([-0.0333, -0.0333, -0.0333])
    grad:  2.0 4.0 tensor([-2.7739, -1.3869, -0.6935])
    grad:  3.0 6.0 tensor([4.01401.33800.4460])
progress:  6 0.04972923547029495
    grad:  1.0 2.0 tensor([-0.0501, -0.0501, -0.0501])
    grad:  2.0 4.0 tensor([-3.1150, -1.5575, -0.7788])
    grad:  3.0 6.0 tensor([2.85340.95110.3170])
progress:  7 0.025129113346338272
    grad:  1.0 2.0 tensor([-0.0205, -0.0205, -0.0205])
    grad:  2.0 4.0 tensor([-2.8858, -1.4429, -0.7215])
    grad:  3.0 6.0 tensor([3.29241.09750.3658])
progress:  8 0.03345605731010437
    grad:  1.0 2.0 tensor([-0.0134, -0.0134, -0.0134])
    grad:  2.0 4.0 tensor([-2.9247, -1.4623, -0.7312])
    grad:  3.0 6.0 tensor([2.99090.99700.3323])
progress:  9 0.027609655633568764
    grad:  1.0 2.0 tensor([0.00330.00330.0033])
    grad:  2.0 4.0 tensor([-2.8414, -1.4207, -0.7103])
    grad:  3.0 6.0 tensor([3.03771.01260.3375])
progress:  10 0.02848036028444767
    grad:  1.0 2.0 tensor([0.01480.01480.0148])
    grad:  2.0 4.0 tensor([-2.8174, -1.4087, -0.7043])
    grad:  3.0 6.0 tensor([2.92600.97530.3251])
progress:  11 0.02642466314136982
    grad:  1.0 2.0 tensor([0.02800.02800.0280])
    grad:  2.0 4.0 tensor([-2.7682, -1.3841, -0.6920])
    grad:  3.0 6.0 tensor([2.89150.96380.3213])
progress:  12 0.025804826989769936
    grad:  1.0 2.0 tensor([0.03970.03970.0397])
    grad:  2.0 4.0 tensor([-2.7330, -1.3665, -0.6832])
    grad:  3.0 6.0 tensor([2.82430.94140.3138])
progress:  13 0.02462013065814972
    grad:  1.0 2.0 tensor([0.05140.05140.0514])
    grad:  2.0 4.0 tensor([-2.6934, -1.3467, -0.6734])
    grad:  3.0 6.0 tensor([2.77560.92520.3084])
progress:  14 0.023777369409799576
    grad:  1.0 2.0 tensor([0.06240.06240.0624])
    grad:  2.0 4.0 tensor([-2.6580, -1.3290, -0.6645])
    grad:  3.0 6.0 tensor([2.72130.90710.3024])
progress:  15 0.0228563379496336
    grad:  1.0 2.0 tensor([0.07310.07310.0731])
    grad:  2.0 4.0 tensor([-2.6227, -1.3113, -0.6557])
    grad:  3.0 6.0 tensor([2.67250.89080.2969])
progress:  16 0.022044027224183083
    grad:  1.0 2.0 tensor([0.08330.08330.0833])
    grad:  2.0 4.0 tensor([-2.5893, -1.2946, -0.6473])
    grad:  3.0 6.0 tensor([2.62400.87470.2916])
progress:  17 0.02125072106719017
    grad:  1.0 2.0 tensor([0.09310.09310.0931])
    grad:  2.0 4.0 tensor([-2.5568, -1.2784, -0.6392])
    grad:  3.0 6.0 tensor([2.57800.85930.2864])
progress:  18 0.020513182505965233
    grad:  1.0 2.0 tensor([0.10250.10250.1025])
    grad:  2.0 4.0 tensor([-2.5258, -1.2629, -0.6314])
    grad:  3.0 6.0 tensor([2.53350.84450.2815])
progress:  19 0.019810274243354797
    grad:  1.0 2.0 tensor([0.11160.11160.1116])
    grad:  2.0 4.0 tensor([-2.4958, -1.2479, -0.6239])
    grad:  3.0 6.0 tensor([2.49080.83030.2768])
progress:  20 0.019148115068674088
    grad:  1.0 2.0 tensor([0.12030.12030.1203])
    grad:  2.0 4.0 tensor([-2.4669, -1.2335, -0.6167])
    grad:  3.0 6.0 tensor([2.44960.81650.2722])
progress:  21 0.018520694226026535
    grad:  1.0 2.0 tensor([0.12860.12860.1286])
    grad:  2.0 4.0 tensor([-2.4392, -1.2196, -0.6098])
    grad:  3.0 6.0 tensor([2.41010.80340.2678])
progress:  22 0.017927465960383415
    grad:  1.0 2.0 tensor([0.13670.13670.1367])
    grad:  2.0 4.0 tensor([-2.4124, -1.2062, -0.6031])
    grad:  3.0 6.0 tensor([2.37200.79070.2636])
progress:  23 0.01736525259912014
    grad:  1.0 2.0 tensor([0.14440.14440.1444])
    grad:  2.0 4.0 tensor([-2.3867, -1.1933, -0.5967])
    grad:  3.0 6.0 tensor([2.33540.77850.2595])
progress:  24 0.016833148896694183
    grad:  1.0 2.0 tensor([0.15180.15180.1518])
    grad:  2.0 4.0 tensor([-2.3619, -1.1810, -0.5905])
    grad:  3.0 6.0 tensor([2.30010.76670.2556])
progress:  25 0.01632905937731266
    grad:  1.0 2.0 tensor([0.15890.15890.1589])
    grad:  2.0 4.0 tensor([-2.3380, -1.1690, -0.5845])
    grad:  3.0 6.0 tensor([2.26620.75540.2518])
progress:  26 0.01585075818002224
    grad:  1.0 2.0 tensor([0.16570.16570.1657])
    grad:  2.0 4.0 tensor([-2.3151, -1.1575, -0.5788])
    grad:  3.0 6.0 tensor([2.23360.74450.2482])
progress:  27 0.015397666022181511
    grad:  1.0 2.0 tensor([0.17230.17230.1723])
    grad:  2.0 4.0 tensor([-2.2929, -1.1465, -0.5732])
    grad:  3.0 6.0 tensor([2.20220.73410.2447])
progress:  28 0.014967591501772404
    grad:  1.0 2.0 tensor([0.17860.17860.1786])
    grad:  2.0 4.0 tensor([-2.2716, -1.1358, -0.5679])
    grad:  3.0 6.0 tensor([2.17190.72400.2413])
progress:  29 0.014559715054929256
    grad:  1.0 2.0 tensor([0.18460.18460.1846])
    grad:  2.0 4.0 tensor([-2.2511, -1.1255, -0.5628])
    grad:  3.0 6.0 tensor([2.14290.71430.2381])
progress:  30 0.014172340743243694
    grad:  1.0 2.0 tensor([0.19040.19040.1904])
    grad:  2.0 4.0 tensor([-2.2313, -1.1157, -0.5578])
    grad:  3.0 6.0 tensor([2.11490.70500.2350])
progress:  31 0.013804304413497448
    grad:  1.0 2.0 tensor([0.19600.19600.1960])
    grad:  2.0 4.0 tensor([-2.2123, -1.1061, -0.5531])
    grad:  3.0 6.0 tensor([2.08790.69600.2320])
progress:  32 0.013455045409500599
    grad:  1.0 2.0 tensor([0.20140.20140.2014])
    grad:  2.0 4.0 tensor([-2.1939, -1.0970, -0.5485])
    grad:  3.0 6.0 tensor([2.06200.68730.2291])
progress:  33 0.013122711330652237
    grad:  1.0 2.0 tensor([0.20650.20650.2065])
    grad:  2.0 4.0 tensor([-2.1763, -1.0881, -0.5441])
    grad:  3.0 6.0 tensor([2.03700.67900.2263])
progress:  34 0.01280694268643856
    grad:  1.0 2.0 tensor([0.21140.21140.2114])
    grad:  2.0 4.0 tensor([-2.1592, -1.0796, -0.5398])
    grad:  3.0 6.0 tensor([2.01300.67100.2237])
progress:  35 0.012506747618317604
    grad:  1.0 2.0 tensor([0.21620.21620.2162])
    grad:  2.0 4.0 tensor([-2.1428, -1.0714, -0.5357])
    grad:  3.0 6.0 tensor([1.98990.66330.2211])
progress:  36 0.012220758944749832
    grad:  1.0 2.0 tensor([0.22070.22070.2207])
    grad:  2.0 4.0 tensor([-2.1270, -1.0635, -0.5317])
    grad:  3.0 6.0 tensor([1.96760.65590.2186])
progress:  37 0.01194891706109047
    grad:  1.0 2.0 tensor([0.22510.22510.2251])
    grad:  2.0 4.0 tensor([-2.1118, -1.0559, -0.5279])
    grad:  3.0 6.0 tensor([1.94620.64870.2162])
progress:  38 0.011689926497638226
    grad:  1.0 2.0 tensor([0.22920.22920.2292])
    grad:  2.0 4.0 tensor([-2.0971, -1.0485, -0.5243])
    grad:  3.0 6.0 tensor([1.92550.64180.2139])
progress:  39 0.01144315768033266
    grad:  1.0 2.0 tensor([0.23330.23330.2333])
    grad:  2.0 4.0 tensor([-2.0829, -1.0414, -0.5207])
    grad:  3.0 6.0 tensor([1.90570.63520.2117])
progress:  40 0.011208509095013142
    grad:  1.0 2.0 tensor([0.23710.23710.2371])
    grad:  2.0 4.0 tensor([-2.0693, -1.0346, -0.5173])
    grad:  3.0 6.0 tensor([1.88650.62880.2096])
progress:  41 0.0109840864315629
    grad:  1.0 2.0 tensor([0.24080.24080.2408])
    grad:  2.0 4.0 tensor([-2.0561, -1.0280, -0.5140])
    grad:  3.0 6.0 tensor([1.86810.62270.2076])
progress:  42 0.010770938359200954
    grad:  1.0 2.0 tensor([0.24440.24440.2444])
    grad:  2.0 4.0 tensor([-2.0434, -1.0217, -0.5108])
    grad:  3.0 6.0 tensor([1.85030.61680.2056])
progress:  43 0.010566935874521732
    grad:  1.0 2.0 tensor([0.24780.24780.2478])
    grad:  2.0 4.0 tensor([-2.0312, -1.0156, -0.5078])
    grad:  3.0 6.0 tensor([1.83320.61110.2037])
progress:  44 0.010372749529778957
    grad:  1.0 2.0 tensor([0.25100.25100.2510])
    grad:  2.0 4.0 tensor([-2.0194, -1.0097, -0.5048])
    grad:  3.0 6.0 tensor([1.81680.60560.2019])
progress:  45 0.010187389329075813
    grad:  1.0 2.0 tensor([0.25420.25420.2542])

    grad:  2.0 4.0 tensor([-2.0080, -1.0040, -0.5020])
    grad:  3.0 6.0 tensor([1.80090.60030.2001])
progress:  46 0.010010283440351486
    grad:  1.0 2.0 tensor([0.25720.25720.2572])
    grad:  2.0 4.0 tensor([-1.9970, -0.9985, -0.4992])
    grad:  3.0 6.0 tensor([1.78560.59520.1984])
progress:  47 0.00984097272157669
    grad:  1.0 2.0 tensor([0.26000.26000.2600])
    grad:  2.0 4.0 tensor([-1.9864, -0.9932, -0.4966])
    grad:  3.0 6.0 tensor([1.77090.59030.1968])
progress:  48 0.009679674170911312
    grad:  1.0 2.0 tensor([0.26280.26280.2628])
    grad:  2.0 4.0 tensor([-1.9762, -0.9881, -0.4940])
    grad:  3.0 6.0 tensor([1.75680.58560.1952])
progress:  49 0.009525291621685028
    grad:  1.0 2.0 tensor([0.26550.26550.2655])
    grad:  2.0 4.0 tensor([-1.9663, -0.9832, -0.4916])
    grad:  3.0 6.0 tensor([1.74310.58100.1937])
progress:  50 0.00937769003212452
    grad:  1.0 2.0 tensor([0.26800.26800.2680])
    grad:  2.0 4.0 tensor([-1.9568, -0.9784, -0.4892])
    grad:  3.0 6.0 tensor([1.72990.57660.1922])
progress:  51 0.009236648678779602
    grad:  1.0 2.0 tensor([0.27040.27040.2704])
    grad:  2.0 4.0 tensor([-1.9476, -0.9738, -0.4869])
    grad:  3.0 6.0 tensor([1.71720.57240.1908])
progress:  52 0.00910158734768629
    grad:  1.0 2.0 tensor([0.27280.27280.2728])
    grad:  2.0 4.0 tensor([-1.9387, -0.9694, -0.4847])
    grad:  3.0 6.0 tensor([1.70500.56830.1894])
progress:  53 0.00897257961332798
    grad:  1.0 2.0 tensor([0.27500.27500.2750])
    grad:  2.0 4.0 tensor([-1.9301, -0.9651, -0.4825])
    grad:  3.0 6.0 tensor([1.69320.56440.1881])
progress:  54 0.008848887868225574
    grad:  1.0 2.0 tensor([0.27710.27710.2771])
    grad:  2.0 4.0 tensor([-1.9219, -0.9609, -0.4805])
    grad:  3.0 6.0 tensor([1.68190.56060.1869])
progress:  55 0.008730598725378513
    grad:  1.0 2.0 tensor([0.27920.27920.2792])
    grad:  2.0 4.0 tensor([-1.9139, -0.9569, -0.4785])
    grad:  3.0 6.0 tensor([1.67090.55700.1857])
progress:  56 0.00861735362559557
    grad:  1.0 2.0 tensor([0.28110.28110.2811])
    grad:  2.0 4.0 tensor([-1.9062, -0.9531, -0.4765])
    grad:  3.0 6.0 tensor([1.66040.55350.1845])
progress:  57 0.008508718572556973
    grad:  1.0 2.0 tensor([0.28300.28300.2830])
    grad:  2.0 4.0 tensor([-1.8987, -0.9493, -0.4747])
    grad:  3.0 6.0 tensor([1.65020.55010.1834])
progress:  58 0.008404706604778767
    grad:  1.0 2.0 tensor([0.28480.28480.2848])
    grad:  2.0 4.0 tensor([-1.8915, -0.9457, -0.4729])
    grad:  3.0 6.0 tensor([1.64040.54680.1823])
progress:  59 0.008305158466100693
    grad:  1.0 2.0 tensor([0.28650.28650.2865])
    grad:  2.0 4.0 tensor([-1.8845, -0.9423, -0.4711])
    grad:  3.0 6.0 tensor([1.63090.54360.1812])
progress:  60 0.00820931326597929
    grad:  1.0 2.0 tensor([0.28820.28820.2882])
    grad:  2.0 4.0 tensor([-1.8778, -0.9389, -0.4694])
    grad:  3.0 6.0 tensor([1.62180.54060.1802])
progress:  61 0.008117804303765297
    grad:  1.0 2.0 tensor([0.28980.28980.2898])
    grad:  2.0 4.0 tensor([-1.8713, -0.9356, -0.4678])
    grad:  3.0 6.0 tensor([1.61300.53770.1792])
progress:  62 0.008029798977077007
    grad:  1.0 2.0 tensor([0.29130.29130.2913])
    grad:  2.0 4.0 tensor([-1.8650, -0.9325, -0.4662])
    grad:  3.0 6.0 tensor([1.60450.53480.1783])
progress:  63 0.007945418357849121
    grad:  1.0 2.0 tensor([0.29270.29270.2927])
    grad:  2.0 4.0 tensor([-1.8589, -0.9294, -0.4647])
    grad:  3.0 6.0 tensor([1.59620.53210.1774])
progress:  64 0.007864190265536308
    grad:  1.0 2.0 tensor([0.29410.29410.2941])
    grad:  2.0 4.0 tensor([-1.8530, -0.9265, -0.4632])
    grad:  3.0 6.0 tensor([1.58840.52950.1765])
progress:  65 0.007786744274199009
    grad:  1.0 2.0 tensor([0.29540.29540.2954])
    grad:  2.0 4.0 tensor([-1.8473, -0.9236, -0.4618])
    grad:  3.0 6.0 tensor([1.58070.52690.1756])
progress:  66 0.007711691781878471
    grad:  1.0 2.0 tensor([0.29670.29670.2967])
    grad:  2.0 4.0 tensor([-1.8417, -0.9209, -0.4604])
    grad:  3.0 6.0 tensor([1.57330.52440.1748])
progress:  67 0.007640169933438301
    grad:  1.0 2.0 tensor([0.29790.29790.2979])
    grad:  2.0 4.0 tensor([-1.8364, -0.9182, -0.4591])
    grad:  3.0 6.0 tensor([1.56620.52210.1740])
progress:  68 0.007570972666144371
    grad:  1.0 2.0 tensor([0.29910.29910.2991])
    grad:  2.0 4.0 tensor([-1.8312, -0.9156, -0.4578])
    grad:  3.0 6.0 tensor([1.55930.51980.1733])
progress:  69 0.007504733745008707
    grad:  1.0 2.0 tensor([0.30020.30020.3002])
    grad:  2.0 4.0 tensor([-1.8262, -0.9131, -0.4566])
    grad:  3.0 6.0 tensor([1.55270.51760.1725])
progress:  70 0.007440924644470215
    grad:  1.0 2.0 tensor([0.30120.30120.3012])
    grad:  2.0 4.0 tensor([-1.8214, -0.9107, -0.4553])
    grad:  3.0 6.0 tensor([1.54630.51540.1718])
progress:  71 0.007379599846899509
    grad:  1.0 2.0 tensor([0.30220.30220.3022])
    grad:  2.0 4.0 tensor([-1.8167, -0.9083, -0.4542])
    grad:  3.0 6.0 tensor([1.54010.51340.1711])
progress:  72 0.007320486940443516
    grad:  1.0 2.0 tensor([0.30320.30320.3032])
    grad:  2.0 4.0 tensor([-1.8121, -0.9060, -0.4530])
    grad:  3.0 6.0 tensor([1.53410.51140.1705])
progress:  73 0.007263725157827139
    grad:  1.0 2.0 tensor([0.30410.30410.3041])
    grad:  2.0 4.0 tensor([-1.8077, -0.9038, -0.4519])
    grad:  3.0 6.0 tensor([1.52830.50940.1698])
progress:  74 0.007209045812487602
    grad:  1.0 2.0 tensor([0.30500.30500.3050])
    grad:  2.0 4.0 tensor([-1.8034, -0.9017, -0.4508])
    grad:  3.0 6.0 tensor([1.52270.50760.1692])
progress:  75 0.007156429346650839
    grad:  1.0 2.0 tensor([0.30580.30580.3058])
    grad:  2.0 4.0 tensor([-1.7992, -0.8996, -0.4498])
    grad:  3.0 6.0 tensor([1.51730.50580.1686])
progress:  76 0.007105532102286816
    grad:  1.0 2.0 tensor([0.30660.30660.3066])
    grad:  2.0 4.0 tensor([-1.7952, -0.8976, -0.4488])
    grad:  3.0 6.0 tensor([1.51210.50400.1680])
progress:  77 0.00705681974068284
    grad:  1.0 2.0 tensor([0.30730.30730.3073])
    grad:  2.0 4.0 tensor([-1.7913, -0.8956, -0.4478])
    grad:  3.0 6.0 tensor([1.50700.50230.1674])
progress:  78 0.007009552326053381
    grad:  1.0 2.0 tensor([0.30810.30810.3081])
    grad:  2.0 4.0 tensor([-1.7875, -0.8937, -0.4469])
    grad:  3.0 6.0 tensor([1.50210.50070.1669])
progress:  79 0.006964194122701883
    grad:  1.0 2.0 tensor([0.30870.30870.3087])
    grad:  2.0 4.0 tensor([-1.7838, -0.8919, -0.4459])
    grad:  3.0 6.0 tensor([1.49740.49910.1664])
progress:  80 0.006920332089066505
    grad:  1.0 2.0 tensor([0.30940.30940.3094])
    grad:  2.0 4.0 tensor([-1.7802, -0.8901, -0.4450])
    grad:  3.0 6.0 tensor([1.49280.49760.1659])
progress:  81 0.006878111511468887
    grad:  1.0 2.0 tensor([0.31000.31000.3100])
    grad:  2.0 4.0 tensor([-1.7767, -0.8883, -0.4442])
    grad:  3.0 6.0 tensor([1.48840.49610.1654])
progress:  82 0.006837360095232725
    grad:  1.0 2.0 tensor([0.31060.31060.3106])
    grad:  2.0 4.0 tensor([-1.7733, -0.8867, -0.4433])
    grad:  3.0 6.0 tensor([1.48410.49470.1649])
progress:  83 0.006797831039875746
    grad:  1.0 2.0 tensor([0.31110.31110.3111])
    grad:  2.0 4.0 tensor([-1.7700, -0.8850, -0.4425])
    grad:  3.0 6.0 tensor([1.48000.49330.1644])
progress:  84 0.006760062649846077
    grad:  1.0 2.0 tensor([0.31170.31170.3117])
    grad:  2.0 4.0 tensor([-1.7668, -0.8834, -0.4417])
    grad:  3.0 6.0 tensor([1.47590.49200.1640])
progress:  85 0.006723103579133749
    grad:  1.0 2.0 tensor([0.31220.31220.3122])
    grad:  2.0 4.0 tensor([-1.7637, -0.8818, -0.4409])
    grad:  3.0 6.0 tensor([1.47200.49070.1636])
progress:  86 0.00668772729113698
    grad:  1.0 2.0 tensor([0.31270.31270.3127])
    grad:  2.0 4.0 tensor([-1.7607, -0.8803, -0.4402])
    grad:  3.0 6.0 tensor([1.46820.48940.1631])
progress:  87 0.006653300020843744
    grad:  1.0 2.0 tensor([0.31310.31310.3131])
    grad:  2.0 4.0 tensor([-1.7577, -0.8789, -0.4394])
    grad:  3.0 6.0 tensor([1.46460.48820.1627])
progress:  88 0.0066203586757183075
    grad:  1.0 2.0 tensor([0.31350.31350.3135])
    grad:  2.0 4.0 tensor([-1.7548, -0.8774, -0.4387])
    grad:  3.0 6.0 tensor([1.46100.48700.1623])
progress:  89 0.0065881176851689816
    grad:  1.0 2.0 tensor([0.31390.31390.3139])
    grad:  2.0 4.0 tensor([-1.7520, -0.8760, -0.4380])
    grad:  3.0 6.0 tensor([1.45760.48590.1620])
progress:  90 0.0065572685562074184
    grad:  1.0 2.0 tensor([0.31430.31430.3143])
    grad:  2.0 4.0 tensor([-1.7493, -0.8747, -0.4373])
    grad:  3.0 6.0 tensor([1.45420.48470.1616])
progress:  91 0.0065271081402897835
    grad:  1.0 2.0 tensor([0.31470.31470.3147])
    grad:  2.0 4.0 tensor([-1.7466, -0.8733, -0.4367])
    grad:  3.0 6.0 tensor([1.45100.48370.1612])
progress:  92 0.00649801641702652
    grad:  1.0 2.0 tensor([0.31500.31500.3150])
    grad:  2.0 4.0 tensor([-1.7441, -0.8720, -0.4360])
    grad:  3.0 6.0 tensor([1.44780.48260.1609])
progress:  93 0.0064699104987084866
    grad:  1.0 2.0 tensor([0.31530.31530.3153])
    grad:  2.0 4.0 tensor([-1.7415, -0.8708, -0.4354])
    grad:  3.0 6.0 tensor([1.44480.48160.1605])
progress:  94 0.006442630663514137
    grad:  1.0 2.0 tensor([0.31560.31560.3156])
    grad:  2.0 4.0 tensor([-1.7391, -0.8695, -0.4348])
    grad:  3.0 6.0 tensor([1.44180.48060.1602])
progress:  95 0.006416172254830599
    grad:  1.0 2.0 tensor([0.31590.31590.3159])
    grad:  2.0 4.0 tensor([-1.7366, -0.8683, -0.4342])
    grad:  3.0 6.0 tensor([1.43890.47960.1599])
progress:  96 0.006390606984496117
    grad:  1.0 2.0 tensor([0.31610.31610.3161])
    grad:  2.0 4.0 tensor([-1.7343, -0.8671, -0.4336])
    grad:  3.0 6.0 tensor([1.43610.47870.1596])
progress:  97 0.0063657015562057495
    grad:  1.0 2.0 tensor([0.31640.31640.3164])
    grad:  2.0 4.0 tensor([-1.7320, -0.8660, -0.4330])
    grad:  3.0 6.0 tensor([1.43340.47780.1593])
progress:  98 0.0063416799530386925
    grad:  1.0 2.0 tensor([0.31660.31660.3166])
    grad:  2.0 4.0 tensor([-1.7297, -0.8649, -0.4324])
    grad:  3.0 6.0 tensor([1.43080.47690.1590])
progress:  99 0.00631808303296566
predict (after tranining)  4 8.544171333312988

损失值随着迭代次数的增加呈递减趋势,如下图所示:

PyTorch梯度下降反向传播实例分析

可以看出:x=4时的预测值约为8.5,与真实值8有所差距,可通过提高迭代次数或者调整学习率、初始参数等方法来减小差距。

读到这里,这篇“PyTorch梯度下降反向传播实例分析”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注亿速云行业资讯频道。

亿速云「云服务器」,即开即用、新一代英特尔至强铂金CPU、三副本存储NVMe SSD云盘,价格低至29元/月。点击查看>>

向AI问一下细节

免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。

AI

开发者交流群×