<?xml version="1.0" encoding="utf-8"?>
<search>
<entry>
<title>[Alibaba Cloud Tianchi] NLP Training Camp 2</title>
<link href="2020/12/29/%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86%E8%AE%AD%E7%BB%83%E8%90%A5-2/"/>
<url>2020/12/29/%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86%E8%AE%AD%E7%BB%83%E8%90%A5-2/</url>
<content type="html"><![CDATA[<h1 id="文本表示方法-Part1"><a href="#文本表示方法-Part1" class="headerlink" title="文本表示方法 Part1"></a>文本表示方法 Part1</h1><p>在机器学习算法的训练过程中,假设给定 $N$ 个样本,每个样本有 $M$ 个特征,这样组成了 $N × M$ 的样本矩阵,然后完成算法的训练和预测。同样的在计算机视觉中可以将图片的像素看作特征,每张图片看作 $hight×width×3$的特征图,一个三维的矩阵来进入计算机进行计算。</p><a id="more"></a><p>但是在自然语言领域,上述方法却不可行:文本是不定长度的。文本表示成计算机能够运算的数字或向量的方法一般称为词嵌入(Word Embedding)方法。词嵌入将不定长的文本转换到定长的空间内,是文本分类的第一步。</p><h2 id="One-hot"><a href="#One-hot" class="headerlink" title="One-hot"></a>One-hot</h2><p>这里的One-hot与数据挖掘任务中的操作是一致的,即将每一个单词使用一个离散的向量表示。具体将每个字/词编码一个索引,然后根据索引进行赋值。</p><h2 id="Bag-of-Words"><a href="#Bag-of-Words" class="headerlink" title="Bag of Words"></a>Bag of Words</h2><p>Bag of Words(词袋表示),也称为Count Vectors,每个文档的字/词可以使用其出现次数来进行表示。<br>在 sklearn 中可以直接 <code>CountVectorizer</code> 来实现这一步骤:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> sklearn.feature_extraction.text <span class="keyword">import</span> CountVectorizer</span><br><span class="line">corpus = [</span><br><span class="line"> <span class="string">'This is the first document.'</span>,</span><br><span class="line"> <span class="string">'This document is the second document.'</span>,</span><br><span class="line"> <span class="string">'And this is the third one.'</span>,</span><br><span class="line"> <span class="string">'Is this the first document?'</span>,</span><br><span class="line">]</span><br><span class="line">vectorizer = CountVectorizer()</span><br><span class="line">vectorizer.fit_transform(corpus).toarray()</span><br></pre></td></tr></table></figure><h2 id="N-gram"><a href="#N-gram" class="headerlink" title="N-gram"></a>N-gram</h2><p>N-gram与Count Vectors类似,不过加入了相邻单词组合成为新的单词,并进行计数。</p><h2 id="TF-IDF"><a href="#TF-IDF" class="headerlink" title="TF-IDF"></a>TF-IDF</h2><p>TF-IDF 分数由两部分组成:第一部分是词语频率(Term Frequency),第二部分是逆文档频率(Inverse Document Frequency)。其中计算语料库中文档总数除以含有该词语的文档数量,然后再取对数就是逆文档频率。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">TF(t)= 该词语在当前文档出现的次数 / 当前文档中词语的总数</span><br><span class="line">IDF(t)= log_e(文档总数 / 出现该词语的文档总数)</span><br></pre></td></tr></table></figure><h1 id="基于机器学习的文本分类"><a href="#基于机器学习的文本分类" class="headerlink" title="基于机器学习的文本分类"></a>基于机器学习的文本分类</h1><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># Count Vectors + RidgeClassifier</span></span><br><span class="line"></span><br><span class="line"><span 
class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.feature_extraction.text <span class="keyword">import</span> CountVectorizer</span><br><span class="line"><span class="keyword">from</span> sklearn.linear_model <span class="keyword">import</span> RidgeClassifier</span><br><span class="line"><span class="keyword">from</span> sklearn.metrics <span class="keyword">import</span> f1_score</span><br><span class="line"></span><br><span class="line">train_df = pd.read_csv(<span class="string">'../data/train_set.csv'</span>, sep=<span class="string">'\t'</span>, nrows=<span class="number">15000</span>)</span><br><span class="line"></span><br><span class="line">vectorizer = CountVectorizer(max_features=<span class="number">3000</span>)</span><br><span class="line">train_test = vectorizer.fit_transform(train_df[<span class="string">'text'</span>])</span><br><span class="line"></span><br><span class="line">clf = RidgeClassifier()</span><br><span class="line">clf.fit(train_test[:<span class="number">10000</span>], train_df[<span class="string">'label'</span>].values[:<span class="number">10000</span>])</span><br><span class="line"></span><br><span class="line">val_pred = clf.predict(train_test[<span class="number">10000</span>:])</span><br><span class="line">print(f1_score(train_df[<span class="string">'label'</span>].values[<span class="number">10000</span>:], val_pred, average=<span class="string">'macro'</span>))</span><br><span class="line"><span class="comment"># 0.65</span></span><br></pre></td></tr></table></figure><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># TF-IDF + RidgeClassifier</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.feature_extraction.text <span class="keyword">import</span> TfidfVectorizer</span><br><span class="line"><span class="keyword">from</span> sklearn.linear_model <span class="keyword">import</span> RidgeClassifier</span><br><span class="line"><span class="keyword">from</span> sklearn.metrics <span class="keyword">import</span> f1_score</span><br><span class="line"></span><br><span class="line">train_df = pd.read_csv(<span class="string">'../data/train_set.csv'</span>, sep=<span class="string">'\t'</span>, nrows=<span class="number">15000</span>)</span><br><span class="line"></span><br><span class="line">tfidf = TfidfVectorizer(ngram_range=(<span class="number">1</span>,<span class="number">3</span>), max_features=<span class="number">3000</span>)</span><br><span class="line">train_test = tfidf.fit_transform(train_df[<span class="string">'text'</span>])</span><br><span 
class="line"></span><br><span class="line">clf = RidgeClassifier()</span><br><span class="line">clf.fit(train_test[:<span class="number">10000</span>], train_df[<span class="string">'label'</span>].values[:<span class="number">10000</span>])</span><br><span class="line"></span><br><span class="line">val_pred = clf.predict(train_test[<span class="number">10000</span>:])</span><br><span class="line">print(f1_score(train_df[<span class="string">'label'</span>].values[<span class="number">10000</span>:], val_pred, average=<span class="string">'macro'</span>))</span><br><span class="line"><span class="comment"># 0.87</span></span><br></pre></td></tr></table></figure><h1 id="作业"><a href="#作业" class="headerlink" title="作业"></a>作业</h1><ol><li>尝试改变TF-IDF的参数,并验证精度<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># TF-IDF (6000 words)+ RidgeClassifier</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> sklearn.feature_extraction.text <span class="keyword">import</span> TfidfVectorizer</span><br><span class="line"><span class="keyword">from</span> sklearn.linear_model <span class="keyword">import</span> RidgeClassifier</span><br><span class="line"><span class="keyword">from</span> sklearn.metrics <span class="keyword">import</span> f1_score</span><br><span class="line"></span><br><span class="line">train_df = pd.read_csv(<span class="string">'./data/train_set.csv'</span>, sep=<span class="string">'\t'</span>, nrows=<span class="number">15000</span>)</span><br><span class="line"></span><br><span class="line">tfidf = TfidfVectorizer(ngram_range=(<span class="number">1</span>,<span class="number">3</span>), max_features=<span class="number">6000</span>)</span><br><span class="line">train_test = tfidf.fit_transform(train_df[<span class="string">'text'</span>])</span><br><span class="line"></span><br><span class="line">clf = RidgeClassifier()</span><br><span class="line">clf.fit(train_test[:<span class="number">10000</span>], train_df[<span class="string">'label'</span>].values[:<span class="number">10000</span>])</span><br><span class="line"></span><br><span class="line">val_pred = clf.predict(train_test[<span class="number">10000</span>:])</span><br><span class="line">print(f1_score(train_df[<span class="string">'label'</span>].values[<span class="number">10000</span>:], val_pred, average=<span class="string">'macro'</span>))</span><br></pre></td></tr></table></figure></li><li>尝试使用其他机器学习模型,完成训练和验证</li></ol>]]></content>
<tags>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[Alibaba Cloud Tianchi] NLP Training Camp 1</title>
<link href="2020/12/29/%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86%E8%AE%AD%E7%BB%83%E8%90%A5-1/"/>
<url>2020/12/29/%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-%E8%87%AA%E7%84%B6%E8%AF%AD%E8%A8%80%E5%A4%84%E7%90%86%E8%AE%AD%E7%BB%83%E8%90%A5-1/</url>
<content type="html"><![CDATA[<h1 id="数据读取"><a href="#数据读取" class="headerlink" title="数据读取"></a>数据读取</h1><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line">train_df = pd.read_csv(<span class="string">'./data/train_set.csv'</span>, sep=<span class="string">'\t'</span>, nrows=<span class="number">100</span>)</span><br><span class="line">train_df.head()</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>label</th><th>text</th></tr></thead><tbody><tr><td>0</td><td>2</td><td>2967 6758 339 2021 1854 3731 4109 3792 4149 15…</td></tr><tr><td>1</td><td>11</td><td>4464 486 6352 5619 2465 4802 1452 3137 5778 54…</td></tr><tr><td>2</td><td>3</td><td>7346 4068 5074 3747 5681 6093 1777 2226 7354 6…</td></tr><tr><td>3</td><td>2</td><td>7159 948 4866 2109 5520 2490 211 3956 5520 549…</td></tr><tr><td>4</td><td>3</td><td>3646 3055 3055 2490 4659 6065 3370 5814 2465 5…</td></tr></tbody></table><a id="more"></a><h1 id="数据分析"><a href="#数据分析" class="headerlink" title="数据分析"></a>数据分析</h1><p>在读取完成数据集后,我们还可以对数据集进行数据分析的操作。虽然对于非结构数据并不需要做很多的数据分析,但通过数据分析还是可以找出一些规律的。</p><p>此步骤我们读取了所有的训练集数据,在此我们通过数据分析希望得出以下结论:</p><ul><li>赛题数据中,新闻文本的长度是多少?</li><li>赛题数据的类别分布是怎么样的,哪些类别比较多?</li><li>赛题数据中,字符分布是怎么样的?<h2 id="句子长度分析"><a href="#句子长度分析" class="headerlink" title="句子长度分析"></a>句子长度分析</h2></li></ul><p>在赛题数据中每行句子的字符使用空格进行隔开,所以可以直接统计单词的个数来得到每个句子的长度。统计并如下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">%pylab inline</span><br><span class="line">train_df[<span class="string">'text_len'</span>] = train_df[<span class="string">'text'</span>].apply(<span class="keyword">lambda</span> x: len(x.split(<span class="string">' '</span>)))</span><br><span class="line">print(train_df[<span class="string">'text_len'</span>].describe())</span><br><span class="line"></span><br><span class="line">Populating the interactive namespace <span class="keyword">from</span> numpy <span class="keyword">and</span> matplotlib</span><br><span class="line">count <span class="number">100.000000</span></span><br><span class="line">mean <span class="number">872.320000</span></span><br><span class="line">std <span class="number">923.138191</span></span><br><span class="line">min <span class="number">64.000000</span></span><br><span class="line"><span class="number">25</span>% <span class="number">359.500000</span></span><br><span class="line"><span class="number">50</span>% <span class="number">598.000000</span></span><br><span class="line"><span class="number">75</span>% <span class="number">1058.000000</span></span><br><span class="line">max <span class="number">7125.000000</span></span><br><span class="line">Name: text_len, dtype: float64</span><br></pre></td></tr></table></figure><h2 id="新闻类别分布"><a href="#新闻类别分布" class="headerlink" title="新闻类别分布"></a>新闻类别分布</h2><figure 
class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">train_df[<span class="string">'label'</span>].value_counts().plot(kind=<span class="string">'bar'</span>)</span><br><span class="line">plt.title(<span class="string">'News class count'</span>)</span><br><span class="line">plt.xlabel(<span class="string">"category"</span>)</span><br></pre></td></tr></table></figure><p>在数据集中标签的对应的关系如下:{‘科技’: 0, ‘股票’: 1, ‘体育’: 2, ‘娱乐’: 3, ‘时政’: 4, ‘社会’: 5, ‘教育’: 6, ‘财经’: 7, ‘家居’: 8, ‘游戏’: 9, ‘房产’: 10, ‘时尚’: 11, ‘彩票’: 12, ‘星座’: 13}</p><p>从统计结果可以看出,赛题的数据集类别分布存在较为不均匀的情况。在训练集中科技类新闻最多,其次是股票类新闻,最少的新闻是星座新闻。</p><h2 id="字符分布统计"><a href="#字符分布统计" class="headerlink" title="字符分布统计"></a>字符分布统计</h2><p>接下来可以统计每个字符出现的次数,首先可以将训练集中所有的句子进行拼接进而划分为字符,并统计每个字符的个数。</p><p>从统计结果中可以看出,在训练集中总共包括6869个字,其中编号3750的字出现的次数最多,编号3133的字出现的次数最少。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> Counter</span><br><span class="line">all_lines = <span class="string">' '</span>.join(list(train_df[<span class="string">'text'</span>]))</span><br><span class="line">word_count = Counter(all_lines.split(<span class="string">" "</span>))</span><br><span class="line">word_count = sorted(word_count.items(), key=<span class="keyword">lambda</span> d:d[<span class="number">1</span>], reverse = <span class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">print(len(word_count)) <span class="comment"># 2405</span></span><br><span class="line"></span><br><span class="line">print(word_count[<span class="number">0</span>]) <span class="comment"># ('3750', 3702)</span></span><br><span class="line"></span><br><span class="line">print(word_count[<span class="number">-1</span>]) <span class="comment"># ('5034', 1)</span></span><br></pre></td></tr></table></figure><p>这里还可以根据字在每个句子的出现情况,反推出标点符号。下面代码统计了不同字符在句子中出现的次数,其中字符3750,字符900和字符648在20w新闻的覆盖率接近99%,很有可能是标点符号。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> collections <span class="keyword">import</span> Counter</span><br><span class="line">train_df[<span class="string">'text_unique'</span>] = train_df[<span class="string">'text'</span>].apply(<span class="keyword">lambda</span> x: <span class="string">' '</span>.join(list(set(x.split(<span class="string">' '</span>)))))</span><br><span class="line">all_lines = <span class="string">' '</span>.join(list(train_df[<span class="string">'text_unique'</span>]))</span><br><span class="line">word_count = Counter(all_lines.split(<span class="string">" "</span>))</span><br><span class="line">word_count = sorted(word_count.items(), key=<span class="keyword">lambda</span> d:int(d[<span class="number">1</span>]), reverse = <span 
class="literal">True</span>)</span><br><span class="line"></span><br><span class="line">print(word_count[:<span class="number">5</span>]) <span class="comment"># [('900', 99), ('3750', 99), ('648', 96), ('7399', 87), ('2109', 86)]</span></span><br></pre></td></tr></table></figure><h1 id="数据分析的结论"><a href="#数据分析的结论" class="headerlink" title="数据分析的结论"></a>数据分析的结论</h1><p>通过上述分析我们可以得出以下结论:</p><ol><li>赛题中每个新闻包含的字符个数平均为1000个,还有一些新闻字符较长;</li><li>赛题中新闻类别分布不均匀,科技类新闻样本量接近4w,星座类新闻样本量不到1k;</li><li>赛题总共包括7000-8000个字符;</li></ol><p>通过数据分析,我们还可以得出以下结论:</p><ol><li>每个新闻平均字符个数较多,可能需要截断;</li><li>由于类别不均衡,会严重影响模型的精度;</li></ol><h1 id="本章作业"><a href="#本章作业" class="headerlink" title="本章作业"></a>本章作业</h1><ol><li>假设字符3750,字符900和字符648是句子的标点符号,请分析赛题每篇新闻平均由多少个句子构成?<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">train_df[<span class="string">'length'</span>] = train_df[<span class="string">'text'</span>].apply(<span class="keyword">lambda</span> x: x.count(<span class="string">'3750'</span>) + x.count(<span class="string">'900'</span>) + x.count(<span class="string">'648'</span>))</span><br><span class="line">print(sum(train_df[<span class="string">'length'</span>]) / <span class="number">100</span>)</span><br></pre></td></tr></table></figure></li><li>统计每类新闻中出现次数对多的字符<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line">lst = [[] <span class="keyword">for</span> _ <span class="keyword">in</span> range(<span class="number">14</span>)]</span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(len(train_df)):</span><br><span class="line"> lst[train_df[<span class="string">'label'</span>][i]].extend(train_df[<span class="string">'text'</span>][i].split())</span><br><span class="line">lst_freq = list(map(Counter, lst))</span><br><span class="line">lst_freq = [sorted(lst.items(), key=<span class="keyword">lambda</span> d:int(d[<span class="number">1</span>]), reverse = <span class="literal">True</span>) <span class="keyword">for</span> lst <span class="keyword">in</span> lst_freq]</span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(len(lst_freq)):</span><br><span class="line"> print(lst_freq[i][<span class="number">0</span>])</span><br><span class="line"> </span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">610</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">531</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">956</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">239</span>)</span><br><span 
class="line">(<span class="string">'3750'</span>, <span class="number">78</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">193</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">491</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">214</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">68</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">51</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">152</span>)</span><br><span class="line">(<span class="string">'3750'</span>, <span class="number">102</span>)</span><br><span class="line">(<span class="string">'4464'</span>, <span class="number">59</span>)</span><br><span class="line">(<span class="string">'648'</span>, <span class="number">6</span>)</span><br></pre></td></tr></table></figure></li></ol>]]></content>
<tags>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[DSU & Alibaba Cloud Tianchi] Python Training Camp Task 4</title>
<link href="2020/12/28/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task-4/"/>
<url>2020/12/28/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task-4/</url>
<content type="html"><![CDATA[<h1 id="数据载入"><a href="#数据载入" class="headerlink" title="数据载入"></a>数据载入</h1><h2 id="pd-read-csv"><a href="#pd-read-csv" class="headerlink" title="pd.read_csv"></a><code>pd.read_csv</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">pd.read_csv(filepath, sep, names)</span><br></pre></td></tr></table></figure><ul><li><code>filepath</code>:待读取文件的路径</li><li><code>sep</code>:CSV 文件的 delimiter</li><li><code>names</code>:列名的列表<a id="more"></a><h2 id="df-merge"><a href="#df-merge" class="headerlink" title="df.merge"></a><code>df.merge</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">df.merge(left, right)</span><br></pre></td></tr></table></figure></li><li><code>left</code>, <code>right</code>:合并的两个数据框<h2 id="df-head"><a href="#df-head" class="headerlink" title="df.head"></a><code>df.head</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">df.head(n=<span class="number">5</span>)</span><br></pre></td></tr></table></figure>查看最上面的 n 行内容。</li></ul><hr><p><strong>数据载入的要点</strong>:需要将所需数据合并在一起。</p><h1 id="数据清洗"><a href="#数据清洗" class="headerlink" title="数据清洗"></a>数据清洗</h1><h2 id="pd-DataFrame"><a href="#pd-DataFrame" class="headerlink" title="pd.DataFrame"></a><code>pd.DataFrame</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">pd.DataFrame(df, columes)</span><br></pre></td></tr></table></figure><p><code>df</code>:目标数据框<br><code>columns</code>:提取的列</p><h2 id="df-shape"><a href="#df-shape" class="headerlink" title="df.shape"></a><code>df.shape</code></h2><p>获得数据框的维度</p><h2 id="df-info"><a href="#df-info" class="headerlink" title="df.info()"></a><code>df.info()</code></h2><p>打印列表的简单信息</p><h2 id="df-fillna"><a href="#df-fillna" class="headerlink" title="df.fillna"></a><code>df.fillna</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">df.fillna(value, inplace)</span><br></pre></td></tr></table></figure><ul><li><code>value</code>:用来填补缺失值的值</li><li><code>inplace</code>:是否在原地完成操作<h2 id="df-satype"><a href="#df-satype" class="headerlink" title="df.satype"></a><code>df.satype</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">df.astype(dtype)</span><br></pre></td></tr></table></figure></li><li><code>dtype</code>:指定的数据类型<h2 id="df-describe"><a href="#df-describe" class="headerlink" title="df.describe"></a><code>df.describe</code></h2>打印数据框的统计信息</li></ul><hr><p><strong>数据清洗的要点</strong>:检查数据是否有缺失值,如果有,需要采取相应的行为(填补或删除)。</p><h1 id="数据分析"><a href="#数据分析" class="headerlink" title="数据分析"></a>数据分析</h1><h2 id="df-groupby"><a href="#df-groupby" class="headerlink" title="df.groupby"></a><code>df.groupby</code></h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">df.groupby(by)</span><br></pre></td></tr></table></figure><ul><li><code>by</code>:确定整合成一组的标准<h2 id="df-sum"><a href="#df-sum" class="headerlink" title="df.sum"></a><code>df.sum</code></h2>求和<h2 
id="df-sort-values"><a href="#df-sort-values" class="headerlink" title="df.sort_values"></a><code>df.sort_values</code></h2><figure class="highlight plain"><figcaption><span>ascending)</span></figcaption><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">```</span><br><span class="line">- `by`:排序依据的行或列</span><br><span class="line">- `ascending`:`True` 则升序排列</span><br><span class="line">## `df.value_counts`</span><br><span class="line">返回不同行的计数</span><br><span class="line">***</span><br><span class="line">**数据分析的要点**:通过合并分组、排序、计数等手段,获得数据的信息。</span><br><span class="line"># 数据可视化</span><br><span class="line">## `df.plot(kind)`</span><br><span class="line">```py</span><br><span class="line">df.plot(kind)</span><br></pre></td></tr></table></figure><code>kind</code>:可视化类型</li><li><code>line</code>:折线图</li><li><code>bar</code>:柱状图</li><li><code>hist</code>:直方图</li><li><code>box</code>:箱型图</li><li><code>kde</code>:密度图</li><li><code>pie</code>:pie 图</li><li><code>scatter</code>:散点图</li></ul><hr><p><strong>数据可视化的要点</strong>:根据不同的数据结构和需求选择合适的可视化方法。</p>]]></content>
<tags>
<tag> Python </tag>
</tags>
</entry>
<entry>
<title>[Lessons Learned] Finishing a Task in One Line of Code</title>
<link href="2020/12/26/%E7%BB%8F%E9%AA%8C%E6%80%BB%E7%BB%93-%E4%B8%80%E8%A1%8C%E4%BB%A3%E7%A0%81%E5%AE%8C%E6%88%90%E4%B8%80%E4%B8%AA%E4%BB%BB%E5%8A%A1/"/>
<url>2020/12/26/%E7%BB%8F%E9%AA%8C%E6%80%BB%E7%BB%93-%E4%B8%80%E8%A1%8C%E4%BB%A3%E7%A0%81%E5%AE%8C%E6%88%90%E4%B8%80%E4%B8%AA%E4%BB%BB%E5%8A%A1/</url>
<content type="html"><![CDATA[<p>2020 年的最后一周写一点轻松的东西,熟练使用这些技巧会给日常工作减负不少:本文中会提到几个使用一行代码可以完成简单任务的方法。</p><a id="more"></a><h1 id="替代-if-语句:Ternary-Operator"><a href="#替代-if-语句:Ternary-Operator" class="headerlink" title="替代 if 语句:Ternary Operator"></a>替代 <code>if</code> 语句:Ternary Operator<a href="https://book.pythontips.com/en/latest/ternary_operators.html" target="_blank" rel="noopener" title="Ternary Operator"></a></h1><h2 id="标准用法"><a href="#标准用法" class="headerlink" title="标准用法"></a>标准用法</h2><p>先来看一个简单的 <code>if</code> 语句:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> <span class="number">5</span> > <span class="number">3</span>:</span><br><span class="line"> print(<span class="string">'5 is more than 3'</span>)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">'5 is less than 3'</span>)</span><br><span class="line"><span class="comment"># '5 is more than 3'</span></span><br></pre></td></tr></table></figure><p>怎么用一行代码完成同样的任务呢?有请 Ternary Operator:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">True_expression <span class="keyword">if</span> True_condition <span class="keyword">else</span> False_expression</span><br></pre></td></tr></table></figure><p>意思是这样的:判断 <code>True_condition</code> 是否成立,如果成立,执行 <code>True_expression</code>;如果不成立,执行 <code>False_expression</code>。那么上面的 <code>if</code> 语句可以写成</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="string">'5 is more than 3'</span> <span class="keyword">if</span> <span class="number">5</span> > <span class="number">3</span> <span class="keyword">else</span> <span class="string">'5 is less than 3'</span></span><br><span class="line"><span class="comment"># '5 is more than 3'</span></span><br></pre></td></tr></table></figure><h2 id="Ternary-Operator-的变体"><a href="#Ternary-Operator-的变体" class="headerlink" title="Ternary Operator 的变体"></a>Ternary Operator 的变体</h2><ol><li>Shorthanded Ternary<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">expression <span class="keyword">or</span> another_expression</span><br></pre></td></tr></table></figure>如果 <code>expression</code> 为 <code>True</code>,那么 <code>another_expression</code> 则不会执行;反之则会执行。比如:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="literal">True</span> <span class="keyword">or</span> <span class="string">'something'</span></span><br><span class="line"><span class="comment"># True</span></span><br><span class="line"><span class="literal">False</span> <span class="keyword">or</span> <span class="string">'something'</span></span><br><span class="line"><span class="comment">#'something'</span></span><br></pre></td></tr></table></figure>因为 <code>None</code> 等于 <code>False</code>,这可以用于检查一个函数是否有输出:<figure class="highlight 
py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">output = <span class="literal">None</span></span><br><span class="line">msg = output <span class="keyword">or</span> <span class="string">'No data returned'</span></span><br><span class="line"><span class="comment"># 'No data returned'</span></span><br></pre></td></tr></table></figure>也可以用来设定一个函数的动态变量:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">display_name</span><span class="params">(name, default=None)</span>:</span></span><br><span class="line"> displayeded_name = dafault <span class="keyword">or</span> name</span><br><span class="line"> print(displayed_name)</span><br><span class="line">display_name(<span class="string">'Jim'</span>) <span class="comment"># 'Jim'</span></span><br><span class="line">display_name(<span class="string">'Tom'</span>, <span class="string">'anonymous123'</span>) <span class="comment"># 'anonymous123'</span></span><br></pre></td></tr></table></figure></li><li>这个变体同样巧妙地应用了 <code>False == 0</code> 和 <code>True == 1</code> 这两个性质。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">(expr_if_False, expr_if_true)[True_or_False]</span><br></pre></td></tr></table></figure>这里 <code>(expr_if_False, expr_if_true)</code> 用了 tuple,但是也可以用 list。也可以用 <code>bool()</code> 函数将所有变量转换为 <code>0</code> 或 <code>1</code>,因为 <code>False == 0</code>,当 <code>True_or_False</code> 为假时会取第一个元素;当 <code>True_or_False</code> 为真时会取第二个元素。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">output = <span class="string">'blabla'</span></span><br><span class="line">[<span class="string">'Not have an output'</span>, <span class="string">'Has an output'</span>][bool(output)]</span><br><span class="line"><span class="comment"># 'Has an output'</span></span><br></pre></td></tr></table></figure><h1 id="替代-def-语句:lambda-匿名函数"><a href="#替代-def-语句:lambda-匿名函数" class="headerlink" title="替代 def 语句:lambda 匿名函数"></a>替代 <code>def</code> 语句:<code>lambda</code> 匿名函数</h1>再来看一个简单的自定义函数:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">add</span><span class="params">(x, y)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> x + y</span><br><span class="line">add(<span class="number">1</span>, <span class="number">2</span>) <span class="comment"># 3</span></span><br></pre></td></tr></table></figure>这个简单的函数可以使用 <code>lambda</code> 匿名函数完成。<code>lambda</code> 匿名函数必须在一行内完成,语法为:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">lambda</span> arguments: 
<h1 id="替代-def-语句:lambda-匿名函数">Replacing <code>def</code> Statements: <code>lambda</code> Anonymous Functions</h1><p>Next, a simple user-defined function:</p><figure class="highlight py"><pre>def add(x, y):
    return x + y
add(1, 2)  # 3</pre></figure><p>A simple function like this can be written as a <code>lambda</code> anonymous function. A <code>lambda</code> must fit on one line, with the syntax:</p><figure class="highlight py"><pre>lambda arguments: expression</pre></figure><p>So the addition function above can also be written as:</p><figure class="highlight py"><pre>add = lambda x, y: x + y
add(1, 2)  # 3</pre></figure>
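<p>A common place where <code>lambda</code> shines (a small extra sketch, not from the original post) is as a throwaway <code>key</code> function:</p><figure class="highlight py"><pre># Sort a list of (name, score) pairs by score, descending.
pairs = [('Jim', 3), ('Tom', 7), ('Ann', 5)]
print(sorted(pairs, key=lambda p: p[1], reverse=True))
# [('Tom', 7), ('Ann', 5), ('Jim', 3)]</pre></figure>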
<h1 id="一行代码生成一个容器:解析式">Building a Container in One Line: Comprehensions</h1><h2 id="基础列表解析式">Basic List Comprehensions</h2><p>Suppose we want a function that squares every number in a list and returns a new list:</p><figure class="highlight py"><pre>def square(lst):
    new_lst = []
    for num in lst:
        new_lst.append(num**2)
    return new_lst</pre></figure><p>Five whole lines for such a simple function is almost embarrassing. Comprehensions to the rescue! They come in list, set, and dict flavors; for a list comprehension the syntax is:</p><figure class="highlight py"><pre>[expression for var in iterable if condition]</pre></figure><p>So the function above can be rewritten as:</p><figure class="highlight py"><pre>[var**2 for var in lst]</pre></figure><p>Done in one line! Note that an element is only processed when it passes the <code>if</code> condition; there is no <code>else</code> in this form of comprehension.</p><h2 id="嵌套列表解析式"><a href="https://spapas.github.io/2016/04/27/python-nested-list-comprehensions/" target="_blank" rel="noopener" title="嵌套列表解析式">Nested List Comprehensions</a></h2><p>More complex list comprehensions involve several variables:</p><figure class="highlight py"><pre>[expression for var1 in iterable1 if condition1 for var2 in iterable2 if condition2]</pre></figure><p>How should we read this? According to PEP 202, the proposal that introduced comprehensions,</p><blockquote><p>It is proposed to allow conditional construction of list literals using for and if clauses. <em>They would nest in the same way for loops and if statements nest now.</em></p></blockquote><p>If we wrote a <code>for</code> loop to do the same thing, it would look like:</p><figure class="highlight py"><pre>for var1 in iterable1:
    if condition1:
        for var2 in iterable2:
            if condition2:
                expression</pre></figure><p>With a comprehension, we simply move the innermost expression to the front and then write the nested loop headers, in order, on the same line. Here is the ultimate example:</p><blockquote><p>Given a nested list of words, return a new list containing, for every word with at least two letters, each of its letters paired with the index of the sublist the word came from.</p></blockquote><p>The conventional version might be:</p><figure class="highlight py"><pre>new_lst = []
for idx, lst in enumerate(nested_lst):
    for word in lst:
        if len(word) >= 2:
            for letter in word:
                new_lst.append((letter, idx))</pre></figure><p>With a list comprehension:</p><figure class="highlight py"><pre>strings = [ ['foo', 'bar'], ['baz', 'taz'], ['w', 'koko'] ]
[ (letter, idx) for idx, lst in enumerate(strings) for word in lst if len(word) >= 2 for letter in word]
# [('f', 0), ('o', 0), ('o', 0), ('b', 0), ('a', 0), ('r', 0), ('b', 1), ('a', 1), ('z', 1), ('t', 1), ('a', 1), ('z', 1), ('k', 2), ('o', 2), ('k', 2), ('o', 2)]</pre></figure><h2 id="集合解析式与字典解析式">Set and Dict Comprehensions</h2><p>A set comprehension is just like a list comprehension, except that <code>[]</code> becomes <code>{}</code>, as in the small sketch below.</p>
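<p>(The sketch is mine, not from the original post: deduplicating words case-insensitively.)</p><figure class="highlight py"><pre>words = ['Foo', 'bar', 'foo', 'BAR']
{w.lower() for w in words}
# {'foo', 'bar'} (set order may vary)</pre></figure>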
class="line">word2id = {word:id <span class="keyword">for</span> (id, word) <span class="keyword">in</span> id2word.items()}</span><br></pre></td></tr></table></figure><h1 id="替代-yield-语句:生成器"><a href="#替代-yield-语句:生成器" class="headerlink" title="替代 yield 语句:生成器"></a>替代 <code>yield</code> 语句:生成器</h1><p><code>yield</code> 语句是一种一边循环一边计算的机制,将一个函数中的 <code>return</code> 替换成 <code>yield</code>,使用 <code>next</code> 调用该函数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">generator</span><span class="params">(x)</span>:</span></span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> range(x):</span><br><span class="line"> <span class="keyword">yield</span> i</span><br><span class="line">x = generator(<span class="number">5</span>)</span><br><span class="line">x <span class="comment"># <generator object generator at 0x7fee19239c80></span></span><br><span class="line">next(x) <span class="comment"># 0</span></span><br><span class="line">next(x) <span class="comment"># 1</span></span><br><span class="line">...</span><br></pre></td></tr></table></figure><p>也可以使用一行代码将一个简单的 <code>yield</code> 语句实现,这就是生成器。生成器与列表解析式的语法一样,只是用 <code>()</code> 代替了 <code>[]</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">x = (i <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">5</span>))</span><br><span class="line">x</span><br><span class="line">next(x) <span class="comment"># 0</span></span><br><span class="line">...</span><br></pre></td></tr></table></figure><hr><p>本文中的技巧仅可用于一些简单的任务,如果任务比较复杂,这些技巧要么无法完成,要么可读性大幅下降而变得难以理解与维护。</p>]]></content>
<tags>
<tag> Lessons Learned </tag>
</tags>
</entry>
<entry>
<title>[DSU & Alibaba Cloud Tianchi] Python Training Camp Task 3</title>
<link href="2020/12/24/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task3/"/>
<url>2020/12/24/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task3/</url>
<content type="html"><![CDATA[<h1 id="函数与-Lambda-表达式"><a href="#函数与-Lambda-表达式" class="headerlink" title="函数与 Lambda 表达式"></a>函数与 Lambda 表达式</h1><h2 id="函数"><a href="#函数" class="headerlink" title="函数"></a>函数</h2><h3 id="函数的定义"><a href="#函数的定义" class="headerlink" title="函数的定义"></a>函数的定义</h3><ol><li>函数以 <code>def</code> 关键词开头,后接函数名和圆括号 <code>()</code>。</li><li>函数执行的代码以冒号起始,并且缩进。</li><li><code>return [表达式]</code> 结束函数,选择性地返回一个值给调用方。不带表达式的 return 相当于返回 <code>None</code> 。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">functionname</span><span class="params">(parameters)</span>:</span></span><br><span class="line"> <span class="string">"函数_文档字符串"</span></span><br><span class="line"> function_suite</span><br><span class="line"> <span class="keyword">return</span> [expression]</span><br></pre></td></tr></table></figure><a id="more"></a><h3 id="函数的调用"><a href="#函数的调用" class="headerlink" title="函数的调用"></a>函数的调用</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">printme</span><span class="params">(str)</span>:</span></span><br><span class="line"> print(str)</span><br><span class="line">printme(<span class="string">"我要调用用户自定义函数!"</span>) <span class="comment"># 我要调用用户自定义函数! 
</span></span><br><span class="line">printme(<span class="string">"再次调用同一函数"</span>) <span class="comment"># 再次调用同一函数</span></span><br><span class="line">temp = printme(<span class="string">'hello'</span>) <span class="comment"># hello</span></span><br><span class="line">print(temp) <span class="comment"># None</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">add</span><span class="params">(a, b)</span>:</span></span><br><span class="line"> print(a + b)</span><br><span class="line">add(<span class="number">1</span>, <span class="number">2</span>) <span class="comment"># 3</span></span><br><span class="line">add([<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>], [<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>]) <span class="comment"># [1, 2, 3, 4, 5, 6]</span></span><br></pre></td></tr></table></figure><h3 id="函数文档"><a href="#函数文档" class="headerlink" title="函数文档"></a>函数文档</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">MyFirstFunction</span><span class="params">(name)</span>:</span></span><br><span class="line"> <span class="string">"函数定义过程中 name 是形参"</span> </span><br><span class="line"> <span class="comment"># 因为 Ta 只是一个形式,表示占据一个参数位置 </span></span><br><span class="line"> print(<span class="string">'传递进来的 {0} 叫做实参,因为 Ta 是具体的参数值!'</span>.format(name))</span><br><span class="line"> </span><br><span class="line">MyFirstFunction(<span class="string">'老马的程序人生'</span>)</span><br><span class="line"><span class="comment"># 传递进来的 老马的程序人生 叫做实参,因为Ta是具体的参数值!</span></span><br><span class="line"></span><br><span class="line">print(MyFirstFunction.__doc__) </span><br><span class="line"><span class="comment"># 函数定义过程中name是形参</span></span><br><span class="line"></span><br><span class="line">help(MyFirstFunction)</span><br><span class="line"><span class="comment"># Help on function MyFirstFunction in module __main__: </span></span><br><span class="line"><span class="comment"># MyFirstFunction(name)</span></span><br><span class="line"><span class="comment"># 函数定义过程中name是形参</span></span><br></pre></td></tr></table></figure><h3 id="函数参数"><a href="#函数参数" class="headerlink" title="函数参数"></a>函数参数</h3></li><li>位置参数<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">functionname</span><span class="params">(arg1)</span>:</span> </span><br><span class="line"> <span class="string">"函数_文档字符串"</span></span><br><span class="line"> function_suite</span><br><span class="line"> <span class="keyword">return</span> 
[expression]</span><br></pre></td></tr></table></figure><code>arg1</code> - 位置参数 ,这些参数在调用函数 (call function) 时位置要固定。</li><li>默认参数<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">functionname</span><span class="params">(arg1, arg2=v)</span>:</span> </span><br><span class="line"> <span class="string">"函数_文档字符串"</span></span><br><span class="line"> function_suite</span><br><span class="line"> <span class="keyword">return</span> [expression]</span><br></pre></td></tr></table></figure></li></ol><ul><li><code>arg2 = v</code> - 默认参数 = 默认值,调用函数时,默认参数的值如果没有传入,则被认为是默认值。</li><li>默认参数一定要放在位置参数后面,不然程序会报错。</li><li>Python 允许函数调用时参数的顺序与声明时不一致,因为 Python 解释器能够用参数名匹配参数值。</li></ul><ol start="3"><li>可变参数<br>可变参数就是传入的参数个数是可变的,可以是 0, 1, 2 到任意个,是不定长的参数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">functionname</span><span class="params">(arg1, arg2=v, *args)</span>:</span> <span class="string">"函数_文档字符串"</span></span><br><span class="line"> function_suite</span><br><span class="line"> <span class="keyword">return</span> [expression]</span><br></pre></td></tr></table></figure></li></ol><ul><li><code>*args</code> - 可变参数,可以是从零个到任意个,自动组装成元组。 </li><li>加了星号(*)的变量名会存放所有未命名的变量参数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">printinfo</span><span class="params">(arg1, *args)</span>:</span></span><br><span class="line"> print(arg1)</span><br><span class="line"> <span class="keyword">for</span> var <span class="keyword">in</span> args:</span><br><span class="line"> print(var)</span><br><span class="line">printinfo(<span class="number">10</span>) <span class="comment"># 10</span></span><br><span class="line">printinfo(<span class="number">70</span>, <span class="number">60</span>, <span class="number">50</span>)</span><br><span class="line"><span class="comment"># 70 </span></span><br><span class="line"><span class="comment"># 60 </span></span><br><span class="line"><span class="comment"># 50</span></span><br></pre></td></tr></table></figure></li></ul><ol start="4"><li>关键字参数<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">functionname</span><span class="params">(arg1, arg2=v, *args, **kw)</span>:</span> </span><br><span class="line"> <span class="string">"函数_文档字符串"</span></span><br><span class="line"> function_suite</span><br><span class="line"> <span class="keyword">return</span> 
[expression]</span><br></pre></td></tr></table></figure><code>**kw</code> - 关键字参数,可以是从零个到任意个,自动组装成字典。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">printinfo</span><span class="params">(arg1, *args, **kwargs)</span>:</span></span><br><span class="line"> print(arg1)</span><br><span class="line"> print(args)</span><br><span class="line"> print(kwargs)</span><br><span class="line">printinfo(<span class="number">70</span>, <span class="number">60</span>, <span class="number">50</span>)</span><br><span class="line"><span class="comment"># 70</span></span><br><span class="line"><span class="comment"># (60, 50)</span></span><br><span class="line"><span class="comment"># {}</span></span><br><span class="line">printinfo(<span class="number">70</span>, <span class="number">60</span>, <span class="number">50</span>, a=<span class="number">1</span>, b=<span class="number">2</span>)</span><br><span class="line"><span class="comment"># 70</span></span><br><span class="line"><span class="comment"># (60, 50)</span></span><br><span class="line"><span class="comment"># {'a': 1, 'b': 2}</span></span><br></pre></td></tr></table></figure></li><li>命名关键字参数<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">functionname</span><span class="params">(arg1, arg2=v, *args, *, nkw, **kw)</span>:</span></span><br><span class="line"> <span class="string">"函数_文档字符串"</span></span><br><span class="line"> function_suite</span><br><span class="line"> <span class="keyword">return</span> [expression]</span><br></pre></td></tr></table></figure></li></ol><ul><li><code>*, nkw</code> - 命名关键字参数,用户想要输入的关键字参数,定义方式是在nkw 前面加个分隔符 <code>*</code>。</li><li>如果要限制关键字参数的名字,就可以用「命名关键字参数」</li><li>使用命名关键字参数时,要特别注意不能缺少参数名。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">printinfo</span><span class="params">(arg1, *, nkw, **kwargs)</span>:</span></span><br><span class="line"> print(arg1)</span><br><span class="line"> print(nkw)</span><br><span class="line"> print(kwargs)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">printinfo(<span class="number">70</span>, nkw=<span class="number">10</span>, a=<span class="number">1</span>, b=<span class="number">2</span>)</span><br><span class="line"><span 
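<p>A small sketch of a signature combining all of these (my example, not from the original tutorial) may make the ordering concrete:</p><figure class="highlight py"><pre># Positional, default, variable-length, keyword-only, and keyword parameters combined.
def f(a, b=2, *args, nkw, **kw):
    print(a, b, args, nkw, kw)

f(1, 3, 4, 5, nkw=6, x=7)
# 1 3 (4, 5) 6 {'x': 7}</pre></figure>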
class="comment"># 70</span></span><br><span class="line"><span class="comment"># 10</span></span><br><span class="line"><span class="comment"># {'a': 1, 'b': 2}</span></span><br><span class="line"></span><br><span class="line">printinfo(<span class="number">70</span>, <span class="number">10</span>, a=<span class="number">1</span>, b=<span class="number">2</span>)</span><br><span class="line"><span class="comment"># TypeError: printinfo() takes 1 positional argument but 2 were given</span></span><br></pre></td></tr></table></figure></li><li>没有写参数名 <code>nwk</code>,因此 <code>10</code> 被当成「位置参数」,而原函数只有 1 个位置函数,现在调用了 2 个,因此程序会报错。<h3 id="参数组合"><a href="#参数组合" class="headerlink" title="参数组合"></a>参数组合</h3>在 Python 中定义函数,可以用位置参数、默认参数、可变参数、命名关键字参数和关键字参数,这 5 种参数中的 4 个都可以一起使用,但是注意,参数定义的顺序必须是:</li><li>位置参数、默认参数、可变参数和关键字参数。</li><li>位置参数、默认参数、命名关键字参数和关键字参数。</li></ul><p>要注意定义可变参数和关键字参数的语法:</p><ul><li><code>*args</code> 是可变参数,<code>args</code> 接收的是一个 <code>tuple</code></li><li><code>**kw</code> 是关键字参数,<code>kw</code> 接收的是一个 <code>dict</code></li></ul><p>命名关键字参数是为了限制调用者可以传入的参数名,同时可以提供默认值。定义命名关键字参数不要忘了写分隔符 <code>*</code>,否则定义的是位置参数。</p><h2 id="函数的返回值"><a href="#函数的返回值" class="headerlink" title="函数的返回值"></a>函数的返回值</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">back</span><span class="params">()</span>:</span></span><br><span class="line"> <span class="keyword">return</span> [<span class="number">1</span>, <span class="string">'小马的程序人生'</span>, <span class="number">3.14</span>]</span><br><span class="line">print(back()) <span class="comment"># [1, '小马的程序人生', 3.14]</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">back</span><span class="params">()</span>:</span></span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span>, <span class="string">'小马的程序人生'</span>, <span class="number">3.14</span></span><br><span class="line">print(back()) <span class="comment"># (1, '小马的程序人生', 3.14)</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">printme</span><span class="params">(str)</span>:</span></span><br><span class="line"> print(str)</span><br><span class="line">temp = printme(<span class="string">'hello'</span>) <span class="comment"># hello</span></span><br><span class="line">print(temp) <span class="comment"># None</span></span><br><span class="line">print(type(temp)) <span class="comment"># <class 'NoneType'></span></span><br></pre></td></tr></table></figure><h2 id="变量的作用域"><a href="#变量的作用域" class="headerlink" title="变量的作用域"></a>变量的作用域</h2><ul><li>Python 中,程序的变量并不是在哪个位置都可以访问的,访问权限决定于这个变量是在哪里赋值的。</li><li>定义在函数内部的变量拥有局部作用域,该变量称为局部变量。</li><li>定义在函数外部的变量拥有全局作用域,该变量称为全局变量。</li><li>局部变量只能在其被声明的函数内部访问,而全局变量可以在整个程序范围内访问。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span 
class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">discounts</span><span class="params">(price, rate)</span>:</span></span><br><span class="line"> final_price = price * rate</span><br><span class="line"> <span class="keyword">return</span> final_price</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">old_price = float(input(<span class="string">'请输入原价:'</span>)) <span class="comment"># 98</span></span><br><span class="line">rate = float(input(<span class="string">'请输入折扣率:'</span>)) <span class="comment"># 0.9</span></span><br><span class="line">new_price = discounts(old_price, rate)</span><br><span class="line">print(<span class="string">'打折后价格是:%.2f'</span> % new_price) <span class="comment"># 88.20</span></span><br></pre></td></tr></table></figure></li><li>当内部作用域想修改外部作用域的变量时,就要用到 <code>global</code> 和 <code>nonlocal</code> 关键字了。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">num = <span class="number">1</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">fun1</span><span class="params">()</span>:</span></span><br><span class="line"> <span class="keyword">global</span> num <span class="comment"># 需要使用 global 关键字声明</span></span><br><span class="line"> print(num) <span class="comment"># 1</span></span><br><span class="line"> num = <span class="number">123</span></span><br><span class="line"> print(num) <span class="comment"># 123</span></span><br><span class="line"></span><br><span class="line">fun1()</span><br><span class="line">print(num) <span class="comment"># 123</span></span><br></pre></td></tr></table></figure><h3 id="内嵌函数"><a href="#内嵌函数" class="headerlink" title="内嵌函数"></a>内嵌函数</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">outer</span><span class="params">()</span>:</span></span><br><span class="line"> print(<span class="string">'outer函数在这被调用'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">inner</span><span class="params">()</span>:</span></span><br><span class="line"> print(<span class="string">'inner函数在这被调用'</span>)</span><br><span class="line"></span><br><span class="line"> inner() <span class="comment"># 
该函数只能在outer函数内部被调用</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">outer()</span><br><span class="line"><span class="comment"># outer函数在这被调用</span></span><br><span class="line"><span class="comment"># inner函数在这被调用</span></span><br></pre></td></tr></table></figure><h3 id="闭包"><a href="#闭包" class="headerlink" title="闭包"></a>闭包</h3></li><li>是函数式编程的一个重要的语法结构,是一种特殊的内嵌函数。</li><li>如果在一个内部函数里对外层非全局作用域的变量进行引用,那么内部函数就被认为是闭包。</li><li>通过闭包可以访问外层非全局作用域的变量,这个作用域称为闭包作用域。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">funX</span><span class="params">(x)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">funY</span><span class="params">(y)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> x * y</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> funY</span><br><span class="line"></span><br><span class="line">i = funX(<span class="number">8</span>)</span><br><span class="line">print(type(i)) <span class="comment"># <class 'function'></span></span><br><span class="line">print(i(<span class="number">5</span>)) <span class="comment"># 40</span></span><br></pre></td></tr></table></figure>闭包的返回值通常是函数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">make_counter</span><span class="params">(init)</span>:</span></span><br><span class="line"> counter = [init]</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">inc</span><span class="params">()</span>:</span> counter[<span class="number">0</span>] += <span class="number">1</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">dec</span><span class="params">()</span>:</span> counter[<span class="number">0</span>] -= <span class="number">1</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">get</span><span class="params">()</span>:</span> <span class="keyword">return</span> counter[<span class="number">0</span>]</span><br><span 
class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">reset</span><span class="params">()</span>:</span> counter[<span class="number">0</span>] = init</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> inc, dec, get, reset</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">inc, dec, get, reset = make_counter(<span class="number">0</span>)</span><br><span class="line">inc()</span><br><span class="line">inc()</span><br><span class="line">inc()</span><br><span class="line">print(get()) <span class="comment"># 3</span></span><br><span class="line">dec()</span><br><span class="line">print(get()) <span class="comment"># 2</span></span><br><span class="line">reset()</span><br><span class="line">print(get()) <span class="comment"># 0</span></span><br></pre></td></tr></table></figure>如果要修改闭包作用域中的变量则需要 <code>nonlocal</code> 关键字<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">outer</span><span class="params">()</span>:</span></span><br><span class="line"> num = <span class="number">10</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">inner</span><span class="params">()</span>:</span></span><br><span class="line"> <span class="keyword">nonlocal</span> num <span class="comment"># nonlocal关键字声明</span></span><br><span class="line"> num = <span class="number">100</span></span><br><span class="line"> print(num)</span><br><span class="line"></span><br><span class="line"> inner()</span><br><span class="line"> print(num)</span><br><span class="line"></span><br><span class="line">outer()</span><br><span class="line"><span class="comment"># 100</span></span><br><span class="line"><span class="comment"># 100</span></span><br></pre></td></tr></table></figure><h3 id="递归"><a href="#递归" class="headerlink" title="递归"></a>递归</h3>如果一个函数在内部调用自身本身,这个函数就是递归函数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 利用循环</span></span><br><span class="line">n = <span class="number">5</span></span><br><span class="line"><span class="keyword">for</span> k <span class="keyword">in</span> range(<span class="number">1</span>, <span class="number">5</span>):</span><br><span class="line"> n = n * k</span><br><span class="line">print(n) <span class="comment"># 120</span></span><br><span class="line"></span><br><span class="line"><span 
class="comment"># 利用递归</span></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">factorial</span><span class="params">(n)</span>:</span></span><br><span class="line"> <span class="keyword">if</span> n == <span class="number">1</span>:</span><br><span class="line"> <span class="keyword">return</span> <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span> n * factorial(n - <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line">print(factorial(<span class="number">5</span>)) <span class="comment"># 120</span></span><br></pre></td></tr></table></figure>设置递归的层数,Python默认递归层数为 <code>100</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> sys</span><br><span class="line">sys.setrecursionlimit(<span class="number">1000</span>)</span><br></pre></td></tr></table></figure><h2 id="Lambda-表达式"><a href="#Lambda-表达式" class="headerlink" title="Lambda 表达式"></a><code>Lambda</code> 表达式</h2><h3 id="匿名函数的定义"><a href="#匿名函数的定义" class="headerlink" title="匿名函数的定义"></a>匿名函数的定义</h3>在 Python 里有两类函数:</li><li>第一类:用 def 关键词定义的正规函数</li><li>第二类:用 lambda 关键词定义的匿名函数</li></ul><p>Python 使用 lambda 关键词来创建匿名函数,而非def关键词,它没有函数名,其语法结构下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">lambda</span> argument_list: expression</span><br></pre></td></tr></table></figure><ul><li><code>lambda</code> - 定义匿名函数的关键词。</li><li><code>argument_list</code> - 函数参数,它们可以是位置参数、默认参数、关键字参数,和正规函数里的参数类型一样。</li><li><code>:</code> - 冒号,在函数参数和表达式中间要加个冒号。</li><li><code>expression</code> - 只是一个表达式,输入函数参数,输出一些值。</li></ul><p>注意:</p><ul><li><p><code>expression</code> 中没有 return 语句,因为 lambda 不需要它来返回,表达式本身结果就是返回值。</p></li><li><p>匿名函数拥有自己的命名空间,且不能访问自己参数列表之外或全局命名空间里的参数。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">sqr</span><span class="params">(x)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> x ** <span class="number">2</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">print(sqr)</span><br><span class="line"><span class="comment"># <function sqr at 0x000000BABD3A4400></span></span><br><span class="line"></span><br><span class="line">y = [sqr(x) <span class="keyword">for</span> x <span class="keyword">in</span> range(<span 
class="number">10</span>)]</span><br><span class="line">print(y)</span><br><span class="line"><span class="comment"># [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]</span></span><br><span class="line"></span><br><span class="line">lbd_sqr = <span class="keyword">lambda</span> x: x ** <span class="number">2</span></span><br><span class="line">print(lbd_sqr)</span><br><span class="line"><span class="comment"># <function <lambda> at 0x000000BABB6AC1E0></span></span><br><span class="line"></span><br><span class="line">y = [lbd_sqr(x) <span class="keyword">for</span> x <span class="keyword">in</span> range(<span class="number">10</span>)]</span><br><span class="line">print(y)</span><br><span class="line"><span class="comment"># [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">sumary = <span class="keyword">lambda</span> arg1, arg2: arg1 + arg2</span><br><span class="line">print(sumary(<span class="number">10</span>, <span class="number">20</span>)) <span class="comment"># 30</span></span><br><span class="line"></span><br><span class="line">func = <span class="keyword">lambda</span> *args: sum(args)</span><br><span class="line">print(func(<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>)) <span class="comment"># 15</span></span><br></pre></td></tr></table></figure><h3 id="匿名函数的应用"><a href="#匿名函数的应用" class="headerlink" title="匿名函数的应用"></a>匿名函数的应用</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">f</span><span class="params">(x)</span>:</span></span><br><span class="line"> y = []</span><br><span class="line"> <span class="keyword">for</span> item <span class="keyword">in</span> x:</span><br><span class="line"> y.append(item + <span class="number">10</span>)</span><br><span class="line"> <span class="keyword">return</span> y</span><br><span class="line"></span><br><span class="line">x = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">f(x)</span><br><span class="line">print(x)</span><br><span class="line"><span class="comment"># [1, 2, 3]</span></span><br></pre></td></tr></table></figure><p>匿名函数 常常应用于函数式编程的高阶函数 (high-order function)中,主要有两种形式:</p></li><li><p>参数是函数 (filter, map)</p></li><li><p>返回值是函数 (closure)<br>如,在 <code>filter</code> 和 <code>map</code> 函数中的应用:</p></li><li><p><code>filter(function, iterable)</code> 过滤序列,过滤掉不符合条件的元素,返回一个迭代器对象,如果要转换为列表,可以使用 <code>list()</code> 来转换。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">odd = <span class="keyword">lambda</span> x: x % <span class="number">2</span> == <span class="number">1</span></span><br><span class="line">templist = filter(odd, [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>, 
<span class="number">6</span>, <span class="number">7</span>, <span class="number">8</span>, <span class="number">9</span>])</span><br><span class="line">print(list(templist)) <span class="comment"># [1, 3, 5, 7, 9]</span></span><br></pre></td></tr></table></figure></li><li><p><code>map(function, *iterables)</code> 根据提供的函数对指定序列做映射。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">m1 = map(<span class="keyword">lambda</span> x: x ** <span class="number">2</span>, [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>])</span><br><span class="line">print(list(m1)) </span><br><span class="line"><span class="comment"># [1, 4, 9, 16, 25]</span></span><br><span class="line"></span><br><span class="line">m2 = map(<span class="keyword">lambda</span> x, y: x + y, [<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>], [<span class="number">2</span>, <span class="number">4</span>, <span class="number">6</span>, <span class="number">8</span>, <span class="number">10</span>])</span><br><span class="line">print(list(m2)) </span><br><span class="line"><span class="comment"># [3, 7, 11, 15, 19]</span></span><br></pre></td></tr></table></figure><h1 id="类与对象"><a href="#类与对象" class="headerlink" title="类与对象"></a>类与对象</h1><h2 id="对象-属性-方法"><a href="#对象-属性-方法" class="headerlink" title="对象 = 属性 + 方法"></a>对象 = 属性 + 方法</h2><p>对象是类的实例。换句话说,类主要定义对象的结构,然后我们以类为模板创建对象。类不但包含方法定义,而且还包含所有实例共享的数据。</p></li><li><p>封装:信息隐蔽技术</p></li></ul><p>我们可以使用关键字 <code>class</code> 定义 Python 类,关键字后面紧跟类的名称、分号和类的实现。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span 
class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Turtle</span>:</span> <span class="comment"># Python中的类名约定以大写字母开头</span></span><br><span class="line"> <span class="string">"""关于类的一个简单例子"""</span></span><br><span class="line"> <span class="comment"># 属性</span></span><br><span class="line"> color = <span class="string">'green'</span></span><br><span class="line"> weight = <span class="number">10</span></span><br><span class="line"> legs = <span class="number">4</span></span><br><span class="line"> shell = <span class="literal">True</span></span><br><span class="line"> mouth = <span class="string">'大嘴'</span></span><br><span class="line"></span><br><span class="line"> <span class="comment"># 方法</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">climb</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'我正在很努力的向前爬...'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">run</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'我正在飞快的向前跑...'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">bite</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'咬死你咬死你!!'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">eat</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'有得吃,真满足...'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">sleep</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'困了,睡了,晚安,zzz'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">tt = Turtle()</span><br><span class="line">print(tt)</span><br><span class="line"><span class="comment"># <__main__.Turtle object at 0x0000007C32D67F98></span></span><br><span class="line"></span><br><span class="line">print(type(tt))</span><br><span class="line"><span class="comment"># <class '__main__.Turtle'></span></span><br><span class="line"></span><br><span class="line">print(tt.__class__)</span><br><span class="line"><span class="comment"># <class '__main__.Turtle'></span></span><br><span class="line"></span><br><span class="line">print(tt.__class__.__name__)</span><br><span class="line"><span class="comment"># Turtle</span></span><br><span class="line"></span><br><span class="line">tt.climb()</span><br><span class="line"><span class="comment"># 我正在很努力的向前爬...</span></span><br><span class="line"></span><br><span class="line">tt.run()</span><br><span class="line"><span class="comment"># 我正在飞快的向前跑...</span></span><br><span class="line"></span><br><span class="line">tt.bite()</span><br><span class="line"><span class="comment"># 咬死你咬死你!!</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Python类也是对象。它们是type的实例</span></span><br><span class="line">print(type(Turtle))</span><br><span 
class="line"><span class="comment"># <class 'type'></span></span><br></pre></td></tr></table></figure><ul><li>继承:子类自动共享父类之间数据和方法的机制<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">MyList</span><span class="params">(list)</span>:</span></span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line">lst = MyList([<span class="number">1</span>, <span class="number">5</span>, <span class="number">2</span>, <span class="number">7</span>, <span class="number">8</span>])</span><br><span class="line">lst.append(<span class="number">9</span>)</span><br><span class="line">lst.sort()</span><br><span class="line">print(lst)</span><br><span class="line"><span class="comment"># [1, 2, 5, 7, 8, 9]</span></span><br></pre></td></tr></table></figure></li><li>多态:不同对象对同一方法响应不同的行动<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Animal</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">run</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">raise</span> AttributeError(<span class="string">'子类必须实现这个方法'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">People</span><span class="params">(Animal)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">run</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'人正在走'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Pig</span><span class="params">(Animal)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">run</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'pig is walking'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span 
class="class"><span class="keyword">class</span> <span class="title">Dog</span><span class="params">(Animal)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">run</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'dog is running'</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">func</span><span class="params">(animal)</span>:</span></span><br><span class="line"> animal.run()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">func(Pig())</span><br><span class="line"><span class="comment"># pig is walking</span></span><br></pre></td></tr></table></figure><h2 id="self"><a href="#self" class="headerlink" title="self"></a>self</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Test</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">prt</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(self)</span><br><span class="line"> print(self.__class__)</span><br><span class="line"></span><br><span class="line">t = Test()</span><br><span class="line">t.prt()</span><br><span class="line"><span class="comment"># <__main__.Test object at 0x000000BC5A351208></span></span><br><span class="line"><span class="comment"># <class '__main__.Test'></span></span><br></pre></td></tr></table></figure>类的方法与普通的函数只有一个特别的区别 —— 它们必须有一个额外的第一个参数名称(对应于该实例,即该对象本身),按照惯例它的名称是 <code>self</code>。在调用方法时,我们无需明确提供与参数 <code>self</code> 相对应的参数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Ball</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">setName</span><span class="params">(self, name)</span>:</span></span><br><span class="line"> self.name = name</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">kick</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">"我叫%s,该死的,谁踢我..."</span> % self.name)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">a = Ball()</span><br><span 
class="line">a.setName(<span class="string">"球A"</span>)</span><br><span class="line">b = Ball()</span><br><span class="line">b.setName(<span class="string">"球B"</span>)</span><br><span class="line">c = Ball()</span><br><span class="line">c.setName(<span class="string">"球C"</span>)</span><br><span class="line">a.kick()</span><br><span class="line"><span class="comment"># 我叫球A,该死的,谁踢我...</span></span><br><span class="line">b.kick()</span><br><span class="line"><span class="comment"># 我叫球B,该死的,谁踢我...</span></span><br></pre></td></tr></table></figure><h2 id="魔术方法"><a href="#魔术方法" class="headerlink" title="魔术方法"></a>魔术方法</h2>类有一个名为 <code>__init__(self[, param1, param2...])</code> 的魔法方法,该方法在类实例化时会自动调用。<h2 id="公有和私有"><a href="#公有和私有" class="headerlink" title="公有和私有"></a>公有和私有</h2>在 Python 中定义私有变量只需要在变量名或函数名前加上 <code>__</code> 两个下划线,那么这个函数或变量就会为私有的了。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">JustCounter</span>:</span></span><br><span class="line"> __secretCount = <span class="number">0</span> <span class="comment"># 私有变量</span></span><br><span class="line"> publicCount = <span class="number">0</span> <span class="comment"># 公开变量</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">count</span><span class="params">(self)</span>:</span></span><br><span class="line"> self.__secretCount += <span class="number">1</span></span><br><span class="line"> self.publicCount += <span class="number">1</span></span><br><span class="line"> print(self.__secretCount)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">counter = JustCounter()</span><br><span class="line">counter.count() <span class="comment"># 1</span></span><br><span class="line">counter.count() <span class="comment"># 2</span></span><br><span class="line">print(counter.publicCount) <span class="comment"># 2</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># Python的私有为伪私有</span></span><br><span class="line">print(counter._JustCounter__secretCount) <span class="comment"># 2 </span></span><br><span class="line">print(counter.__secretCount) </span><br><span class="line"><span class="comment"># AttributeError: 'JustCounter' object has no attribute '__secretCount'</span></span><br></pre></td></tr></table></figure>类的私有方法实例<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span 
class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Site</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, name, url)</span>:</span></span><br><span class="line"> self.name = name <span class="comment"># public</span></span><br><span class="line"> self.__url = url <span class="comment"># private</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">who</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'name : '</span>, self.name)</span><br><span class="line"> print(<span class="string">'url : '</span>, self.__url)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__foo</span><span class="params">(self)</span>:</span> <span class="comment"># 私有方法</span></span><br><span class="line"> print(<span class="string">'这是私有方法'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">foo</span><span class="params">(self)</span>:</span> <span class="comment"># 公共方法</span></span><br><span class="line"> print(<span class="string">'这是公共方法'</span>)</span><br><span class="line"> self.__foo()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">x = Site(<span class="string">'老马的程序人生'</span>, <span class="string">'https://blog.csdn.net/LSGO_MYP'</span>)</span><br><span class="line">x.who()</span><br><span class="line"><span class="comment"># name : 老马的程序人生</span></span><br><span class="line"><span class="comment"># url : https://blog.csdn.net/LSGO_MYP</span></span><br><span class="line"></span><br><span class="line">x.foo()</span><br><span class="line"><span class="comment"># 这是公共方法</span></span><br><span class="line"><span class="comment"># 这是私有方法</span></span><br><span class="line"></span><br><span class="line">x.__foo()</span><br><span class="line"><span class="comment"># AttributeError: 'Site' object has no attribute '__foo'</span></span><br></pre></td></tr></table></figure><h2 id="继承"><a href="#继承" class="headerlink" title="继承"></a>继承</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DerivedClassName</span><span class="params">(BaseClassName)</span>:</span></span><br><span class="line"> statement<span class="number">-1</span></span><br><span class="line"> .</span><br><span class="line"> .</span><br><span class="line"> .</span><br><span class="line"> 
statement-N</span><br></pre></td></tr></table></figure>BaseClassName(基类名)必须与派生类定义在一个作用域内。如果子类中定义与父类同名的方法或属性,则会自动覆盖父类对应的方法或属性(后面的单继承示例会演示这一点)。基类处除了写类名,还可以用表达式,当基类定义在另一个模块中时,这一点非常有用,见下面的示意。
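<p>一个最小的示意(模块名 modname 及其中的基类 BaseClassName 均为假设,仅用来说明这种语法):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 假设模块 modname 中定义了基类 BaseClassName</span></span><br><span class="line"><span class="keyword">import</span> modname</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">DerivedClassName</span><span class="params">(modname.BaseClassName)</span>:</span></span><br><span class="line"> <span class="keyword">pass</span></span><br></pre></td></tr></table></figure>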
class="title">speak</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">"%s 说: 我 %d 岁了,我在读 %d 年级"</span> % (self.name, self.age, self.grade))</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">s = student(<span class="string">'小马的程序人生'</span>, <span class="number">10</span>, <span class="number">60</span>, <span class="number">3</span>)</span><br><span class="line">s.speak()</span><br><span class="line"><span class="comment"># 小马的程序人生 说: 我 10 岁了,我在读 3 年级</span></span><br></pre></td></tr></table></figure>注意:如果上面的程序去掉:<code>people.__init__(self, n, a, w)</code>,则输出:说: <code>我 0 岁了,我在读 3 年级</code>,因为子类的构造方法继承了父类的变量。解决该问题可用以下两种方式:</li><li>调用未绑定的父类方法 <code>Fish.__init__(self)</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Shark</span><span class="params">(Fish)</span>:</span> <span class="comment"># 鲨鱼</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"> Fish.__init__(self)</span><br><span class="line"> self.hungry = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">eat</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">if</span> self.hungry:</span><br><span class="line"> print(<span class="string">"吃货的梦想就是天天有得吃!"</span>)</span><br><span class="line"> self.hungry = <span class="literal">False</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">"太撑了,吃不下了!"</span>)</span><br><span class="line"> self.hungry = <span class="literal">True</span></span><br></pre></td></tr></table></figure></li><li>使用 <code>super</code> 函数 <code>super().__init__()</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Shark</span><span class="params">(Fish)</span>:</span> <span class="comment"># 鲨鱼</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"> self.hungry = <span class="literal">True</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span 
class="title">eat</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">if</span> self.hungry:</span><br><span class="line"> print(<span class="string">"吃货的梦想就是天天有得吃!"</span>)</span><br><span class="line"> self.hungry = <span class="literal">False</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">"太撑了,吃不下了!"</span>)</span><br><span class="line"> self.hungry = <span class="literal">True</span></span><br></pre></td></tr></table></figure><h2 id="类、类对象和实例对象"><a href="#类、类对象和实例对象" class="headerlink" title="类、类对象和实例对象"></a>类、类对象和实例对象</h2></li><li>类对象:创建一个类,其实也是一个对象也在内存开辟了一块空间,称为类对象,类对象只有一个。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">A</span><span class="params">(object)</span>:</span></span><br><span class="line"> <span class="keyword">pass</span></span><br></pre></td></tr></table></figure></li><li>实例对象:就是通过实例化类创建的对象,称为实例对象,实例对象可以有多个。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">A</span><span class="params">(object)</span>:</span></span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 实例化对象 a、b、c都属于实例对象。</span></span><br><span class="line">a = A()</span><br><span class="line">b = A()</span><br><span class="line">c = A()</span><br></pre></td></tr></table></figure></li><li>类属性:类里面方法外面定义的变量称为类属性。类属性所属于类对象并且多个实例对象之间共享同一个类属性,说白了就是类属性所有的通过该类实例化的对象都能共享。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">A</span><span class="params">()</span>:</span></span><br><span class="line"> a = <span class="number">0</span> <span class="comment">#类属性</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, xx)</span>:</span></span><br><span class="line"> A.a = xx <span class="comment">#使用类属性可以通过 (类名.类属性)调用。</span></span><br></pre></td></tr></table></figure></li><li>实例属性:实例属性和具体的某个实例对象有关系,并且一个实例对象和另外一个实例对象是不共享属性的,说白了实例属性只能在自己的对象里面使用,其他的对象不能直接使用,因为 <code>self</code> 是谁调用,它的值就属于该对象。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span 
class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br><span class="line">41</span><br><span class="line">42</span><br><span class="line">43</span><br><span class="line">44</span><br><span class="line">45</span><br><span class="line">46</span><br><span class="line">47</span><br><span class="line">48</span><br><span class="line">49</span><br><span class="line">50</span><br><span class="line">51</span><br><span class="line">52</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 创建类对象</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Test</span><span class="params">(object)</span>:</span></span><br><span class="line"> class_attr = <span class="number">100</span> <span class="comment"># 类属性</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"> self.sl_attr = <span class="number">100</span> <span class="comment"># 实例属性</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">func</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">'类对象.类属性的值:'</span>, Test.class_attr) <span class="comment"># 调用类属性</span></span><br><span class="line"> print(<span class="string">'self.类属性的值'</span>, self.class_attr) <span class="comment"># 相当于把类属性 变成实例属性</span></span><br><span class="line"> print(<span class="string">'self.实例属性的值'</span>, self.sl_attr) <span class="comment"># 调用实例属性</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">a = Test()</span><br><span class="line">a.func()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 类对象.类属性的值: 100</span></span><br><span class="line"><span class="comment"># self.类属性的值 100</span></span><br><span class="line"><span class="comment"># self.实例属性的值 100</span></span><br><span class="line"></span><br><span class="line">b = Test()</span><br><span class="line">b.func()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 类对象.类属性的值: 100</span></span><br><span class="line"><span class="comment"># self.类属性的值 100</span></span><br><span class="line"><span class="comment"># self.实例属性的值 100</span></span><br><span class="line"></span><br><span class="line">a.class_attr = <span class="number">200</span></span><br><span class="line">a.sl_attr = <span class="number">200</span></span><br><span class="line">a.func()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 类对象.类属性的值: 100</span></span><br><span class="line"><span class="comment"># self.类属性的值 200</span></span><br><span 
class="line"><span class="comment"># self.实例属性的值 200</span></span><br><span class="line"></span><br><span class="line">b.func()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 类对象.类属性的值: 100</span></span><br><span class="line"><span class="comment"># self.类属性的值 100</span></span><br><span class="line"><span class="comment"># self.实例属性的值 100</span></span><br><span class="line"></span><br><span class="line">Test.class_attr = <span class="number">300</span></span><br><span class="line">a.func()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 类对象.类属性的值: 300</span></span><br><span class="line"><span class="comment"># self.类属性的值 200</span></span><br><span class="line"><span class="comment"># self.实例属性的值 200</span></span><br><span class="line"></span><br><span class="line">b.func()</span><br><span class="line"><span class="comment"># 类对象.类属性的值: 300</span></span><br><span class="line"><span class="comment"># self.类属性的值 300</span></span><br><span class="line"><span class="comment"># self.实例属性的值 100</span></span><br></pre></td></tr></table></figure><h2 id="绑定"><a href="#绑定" class="headerlink" title="绑定"></a>绑定</h2>Python 严格要求方法需要有实例才能被调用,这种限制其实就是 Python 所谓的绑定概念。</li></ul><p>Python 对象的数据属性通常存储在名为 <code>.__ dict__</code> 的字典中,我们可以直接访问 <code>__dict__</code>,或利用 Python 的内置函数 <code>vars()</code> 获取 <code>.__ dict__</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">CC</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">setXY</span><span class="params">(self, x, y)</span>:</span></span><br><span class="line"> self.x = x</span><br><span class="line"> self.y = y</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">printXY</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(self.x, self.y)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">dd = CC()</span><br><span class="line">print(dd.__dict__)</span><br><span class="line"><span class="comment"># {}</span></span><br><span class="line"></span><br><span class="line">print(vars(dd))</span><br><span class="line"><span class="comment"># {}</span></span><br><span class="line"></span><br><span class="line">print(CC.__dict__)</span><br><span class="line"><span class="comment"># {'__module__': '__main__', 'setXY': <function CC.setXY at 0x000000C3473DA048>, 
'printXY': <function CC.printXY at 0x000000C3473C4F28>, '__dict__': <attribute '__dict__' of 'CC' objects>, '__weakref__': <attribute '__weakref__' of 'CC' objects>, '__doc__': None}</span></span><br><span class="line"></span><br><span class="line">dd.setXY(<span class="number">4</span>, <span class="number">5</span>)</span><br><span class="line">print(dd.__dict__)</span><br><span class="line"><span class="comment"># {'x': 4, 'y': 5}</span></span><br><span class="line"></span><br><span class="line">print(vars(CC))</span><br><span class="line"><span class="comment"># {'__module__': '__main__', 'setXY': <function CC.setXY at 0x000000632CA9B048>, 'printXY': <function CC.printXY at 0x000000632CA83048>, '__dict__': <attribute '__dict__' of 'CC' objects>, '__weakref__': <attribute '__weakref__' of 'CC' objects>, '__doc__': None}</span></span><br><span class="line"></span><br><span class="line">print(CC.__dict__)</span><br><span class="line"><span class="comment"># {'__module__': '__main__', 'setXY': <function CC.setXY at 0x000000632CA9B048>, 'printXY': <function CC.printXY at 0x000000632CA83048>, '__dict__': <attribute '__dict__' of 'CC' objects>, '__weakref__': <attribute '__weakref__' of 'CC' objects>, '__doc__': None}</span></span><br></pre></td></tr></table></figure><h2 id="一些相关的内置函数(BIF)"><a href="#一些相关的内置函数(BIF)" class="headerlink" title="一些相关的内置函数(BIF)"></a>一些相关的内置函数(BIF)</h2><ul><li><code>issubclass(class, classinfo)</code> 方法用于判断参数 class 是否是类型参数 classinfo 的子类。</li><li>一个类被认为是其自身的子类。</li><li><code>classinfo</code> 可以是类对象的元组,只要 class 是其中任何一个候选类的子类,则返回 True。</li><li>isinstance(object, classinfo) 方法用于判断一个对象是否是一个已知的类型,类似type()。</li><li>type()不会认为子类是一种父类类型,不考虑继承关系。</li><li>isinstance()会认为子类是一种父类类型,考虑继承关系。</li><li>如果第一个参数不是对象,则永远返回False。</li><li>如果第二个参数不是类或者由类对象组成的元组,会抛出一个TypeError异常。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line">a = <span class="number">2</span></span><br><span class="line">print(isinstance(a, int)) <span class="comment"># True</span></span><br><span class="line">print(isinstance(a, str)) <span class="comment"># False</span></span><br><span class="line">print(isinstance(a, (str, int, list))) <span class="comment"># True</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">A</span>:</span></span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">B</span><span class="params">(A)</span>:</span></span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"></span><br><span class="line">print(isinstance(A(), A)) <span class="comment"># True</span></span><br><span class="line">print(type(A()) == A) <span class="comment"># True</span></span><br><span class="line">print(isinstance(B(), A)) <span class="comment"># 
True</span></span><br><span class="line">print(type(B()) == A) <span class="comment"># False</span></span><br></pre></td></tr></table></figure></li><li><code>hasattr(object, name)</code> 用于判断对象是否包含对应的属性。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Coordinate</span>:</span></span><br><span class="line"> x = <span class="number">10</span></span><br><span class="line"> y = <span class="number">-5</span></span><br><span class="line"> z = <span class="number">0</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">point1 = Coordinate()</span><br><span class="line">print(hasattr(point1, <span class="string">'x'</span>)) <span class="comment"># True</span></span><br><span class="line">print(hasattr(point1, <span class="string">'y'</span>)) <span class="comment"># True</span></span><br><span class="line">print(hasattr(point1, <span class="string">'z'</span>)) <span class="comment"># True</span></span><br><span class="line">print(hasattr(point1, <span class="string">'no'</span>)) <span class="comment"># False</span></span><br></pre></td></tr></table></figure></li><li><code>getattr(object, name[, default])</code> 用于返回一个对象属性值。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">A</span><span class="params">(object)</span>:</span></span><br><span class="line"> bar = <span class="number">1</span></span><br><span class="line"></span><br><span class="line">a = A()</span><br><span class="line">print(getattr(a, <span class="string">'bar'</span>)) <span class="comment"># 1</span></span><br><span class="line">print(getattr(a, <span class="string">'bar2'</span>, <span class="number">3</span>)) <span class="comment"># 3</span></span><br><span class="line">print(getattr(a, <span class="string">'bar2'</span>))</span><br><span class="line"><span class="comment"># AttributeError: 'A' object has no attribute 'bar2'</span></span><br></pre></td></tr></table></figure></li><li><code>setattr(object, name, value)</code> 对应函数 <code>getattr()</code>,用于设置属性值,该属性不一定是存在的。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">A</span><span class="params">(object)</span>:</span></span><br><span class="line"> bar = <span class="number">1</span></span><br><span class="line"></span><br><span class="line">a = 
A()</span><br><span class="line">print(getattr(a, <span class="string">'bar'</span>)) <span class="comment"># 1</span></span><br><span class="line">setattr(a, <span class="string">'bar'</span>, <span class="number">5</span>)</span><br><span class="line">print(a.bar) <span class="comment"># 5</span></span><br><span class="line">setattr(a, <span class="string">"age"</span>, <span class="number">28</span>)</span><br><span class="line">print(a.age) <span class="comment"># 28</span></span><br></pre></td></tr></table></figure></li><li><code>delattr(object, name)</code> 用于删除属性。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Coordinate</span>:</span></span><br><span class="line"> x = <span class="number">10</span></span><br><span class="line"> y = <span class="number">-5</span></span><br><span class="line"> z = <span class="number">0</span></span><br><span class="line"></span><br><span class="line"></span><br><span class="line">point1 = Coordinate()</span><br><span class="line"></span><br><span class="line">print(<span class="string">'x = '</span>, point1.x) <span class="comment"># x = 10</span></span><br><span class="line">print(<span class="string">'y = '</span>, point1.y) <span class="comment"># y = -5</span></span><br><span class="line">print(<span class="string">'z = '</span>, point1.z) <span class="comment"># z = 0</span></span><br><span class="line"></span><br><span class="line">delattr(Coordinate, <span class="string">'z'</span>)</span><br><span class="line"></span><br><span class="line">print(<span class="string">'--删除 z 属性后--'</span>) <span class="comment"># --删除 z 属性后--</span></span><br><span class="line">print(<span class="string">'x = '</span>, point1.x) <span class="comment"># x = 10</span></span><br><span class="line">print(<span class="string">'y = '</span>, point1.y) <span class="comment"># y = -5</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 触发错误</span></span><br><span class="line">print(<span class="string">'z = '</span>, point1.z)</span><br><span class="line"><span class="comment"># AttributeError: 'Coordinate' object has no attribute 'z'</span></span><br></pre></td></tr></table></figure><h1 id="魔术方法-1"><a href="#魔术方法-1" class="headerlink" title="魔术方法"></a>魔术方法</h1>魔法方法的第一个参数应为 <code>cls</code>(类方法) 或者 <code>self</code>(实例方法)。</li><li><code>cls</code>:代表一个类的名称</li><li><code>self</code>:代表一个实例对象的名称<h2 id="基本的魔法方法"><a href="#基本的魔法方法" class="headerlink" title="基本的魔法方法"></a>基本的魔法方法</h2></li><li><code>__init__(self[, ...])</code> 构造器,当一个实例被创建的时候调用的初始化方法。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span 
class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Rectangle</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, x, y)</span>:</span></span><br><span class="line"> self.x = x</span><br><span class="line"> self.y = y</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">getPeri</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> (self.x + self.y) * <span class="number">2</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">getArea</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> self.x * self.y</span><br><span class="line"></span><br><span class="line">rect = Rectangle(<span class="number">4</span>, <span class="number">5</span>)</span><br><span class="line">print(rect.getPeri()) <span class="comment"># 18</span></span><br><span class="line">print(rect.getArea()) <span class="comment"># 20</span></span><br></pre></td></tr></table></figure></li><li><code>__str__(self)</code>:<ul><li>当你打印一个对象的时候,触发 <code>__str__</code></li><li>当你使用 <code>%s</code> 格式化的时候,触发 <code>__str__</code></li><li><code>str</code> 强转数据类型的时候,触发 <code>__str__</code></li></ul></li><li><code>__repr__(self)</code>:<ul><li><code>repr</code> 是 <code>str</code> 的备胎</li><li>有 <code>__str__</code> 的时候执行 <code>__str__</code>,没有实现 <code>__str__</code> 的时候,执行 <code>__repr__</code></li><li><code>repr(obj)</code> 内置函数对应的结果是 <code>__repr__</code> 的返回值</li><li>当你使用 <code>%r</code> 格式化的时候 触发 <code>__repr__</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Cat</span>:</span></span><br><span class="line"> <span class="string">"""定义一个猫类"""</span></span><br><span 
class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, new_name, new_age)</span>:</span></span><br><span class="line"> <span class="string">"""在创建完对象之后 会自动调用, 它完成对象的初始化的功能"""</span></span><br><span class="line"> self.name = new_name</span><br><span class="line"> self.age = new_age</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__str__</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="string">"""返回一个对象的描述信息"""</span></span><br><span class="line"> <span class="keyword">return</span> <span class="string">"名字是:%s , 年龄是:%d"</span> % (self.name, self.age)</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__repr__</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="string">"""返回一个对象的描述信息"""</span></span><br><span class="line"> <span class="keyword">return</span> <span class="string">"Cat:(%s,%d)"</span> % (self.name, self.age)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">eat</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">"%s在吃鱼...."</span> % self.name)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">drink</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">"%s在喝可乐..."</span> % self.name)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">introduce</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">"名字是:%s, 年龄是:%d"</span> % (self.name, self.age))</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="comment"># 创建了一个对象</span></span><br><span class="line">tom = Cat(<span class="string">"汤姆"</span>, <span class="number">30</span>)</span><br><span class="line">print(tom) <span class="comment"># 名字是:汤姆 , 年龄是:30</span></span><br><span class="line">print(str(tom)) <span class="comment"># 名字是:汤姆 , 年龄是:30</span></span><br><span class="line">print(repr(tom)) <span class="comment"># Cat:(汤姆,30)</span></span><br><span class="line">tom.eat() <span class="comment"># 汤姆在吃鱼....</span></span><br><span class="line">tom.introduce() <span class="comment"># 名字是:汤姆, 年龄是:30</span></span><br></pre></td></tr></table></figure><h2 id="算术运算符"><a href="#算术运算符" class="headerlink" title="算术运算符"></a>算术运算符</h2></li></ul></li><li><code>__add__(self, other)</code> 定义加法的行为:<code>+</code></li><li><code>__sub__(self, other)</code> 定义减法的行为:<code>-</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span 
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">MyClass</span>:</span></span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, height, weight)</span>:</span></span><br><span class="line"> self.height = height</span><br><span class="line"> self.weight = weight</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 两个对象的长相加,宽不变.返回一个新的类</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__add__</span><span class="params">(self, others)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> MyClass(self.height + others.height, self.weight + others.weight)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 两个对象的宽相减,长不变.返回一个新的类</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__sub__</span><span class="params">(self, others)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> MyClass(self.height - others.height, self.weight - others.weight)</span><br><span class="line"></span><br><span class="line"> <span class="comment"># 说一下自己的参数</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">intro</span><span class="params">(self)</span>:</span></span><br><span class="line"> print(<span class="string">"高为"</span>, self.height, <span class="string">" 重为"</span>, self.weight)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">main</span><span class="params">()</span>:</span></span><br><span class="line"> a = MyClass(height=<span class="number">10</span>, weight=<span class="number">5</span>)</span><br><span class="line"> a.intro()</span><br><span class="line"></span><br><span class="line"> b = MyClass(height=<span class="number">20</span>, weight=<span class="number">10</span>)</span><br><span class="line"> b.intro()</span><br><span class="line"></span><br><span class="line"> c = b - a</span><br><span class="line"> c.intro()</span><br><span class="line"></span><br><span class="line"> d = a + b</span><br><span class="line"> d.intro()</span><br><span class="line"></span><br><span class="line"></span><br><span class="line"><span class="keyword">if</span> __name__ == <span class="string">'__main__'</span>:</span><br><span class="line"> main()</span><br><span class="line"></span><br><span class="line"><span class="comment"># 
高为 10 重为 5</span></span><br><span class="line"><span class="comment"># 高为 20 重为 10</span></span><br><span class="line"><span class="comment"># 高为 10 重为 5</span></span><br><span class="line"><span class="comment"># 高为 30 重为 15</span></span><br></pre></td></tr></table></figure></li><li><code>__mul__(self, other)</code> 定义乘法的行为:<code>*</code></li><li><code>__truediv__(self, other)</code> 定义真除法的行为:<code>/</code></li><li><code>__floordiv__(self, other)</code> 定义整数除法的行为:<code>//</code></li><li><code>__mod__(self, other)</code> 定义取模算法的行为:<code>%</code></li><li><code>__divmod__(self, other)</code> 定义当被 <code>divmod()</code> 调用时的行为</li><li><code>divmod(a, b)</code> 把除法的商和余数结合起来,返回一个包含商和余数的元组 <code>(a // b, a % b)</code>。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">print(divmod(<span class="number">7</span>, <span class="number">2</span>)) <span class="comment"># (3, 1)</span></span><br><span class="line">print(divmod(<span class="number">8</span>, <span class="number">2</span>)) <span class="comment"># (4, 0)</span></span><br></pre></td></tr></table></figure></li><li><code>__pow__(self, other[, modulo])</code> 定义当被 <code>pow()</code> 调用或 <code>**</code> 运算时的行为</li><li><code>__lshift__(self, other)</code> 定义按位左移位的行为:<code><<</code></li><li><code>__rshift__(self, other)</code> 定义按位右移位的行为:<code>>></code></li><li><code>__and__(self, other)</code> 定义按位与操作的行为:<code>&</code></li><li><code>__xor__(self, other)</code> 定义按位异或操作的行为:<code>^</code></li><li><code>__or__(self, other)</code> 定义按位或操作的行为:<code>|</code><h2 id="增量赋值运算符"><a href="#增量赋值运算符" class="headerlink" title="增量赋值运算符"></a>增量赋值运算符</h2></li><li><code>__iadd__(self, other)</code> 定义赋值加法的行为:<code>+=</code></li><li><code>__isub__(self, other)</code> 定义赋值减法的行为:<code>-=</code></li><li><code>__imul__(self, other)</code> 定义赋值乘法的行为:<code>*=</code></li><li><code>__itruediv__(self, other)</code> 定义赋值真除法的行为:<code>/=</code></li><li><code>__ifloordiv__(self, other)</code> 定义赋值整数除法的行为:<code>//=</code></li><li><code>__imod__(self, other)</code> 定义赋值取模算法的行为:<code>%=</code></li><li><code>__ipow__(self, other[, modulo])</code> 定义赋值幂运算的行为:<code>**=</code></li><li><code>__ilshift__(self, other)</code> 定义赋值按位左移位的行为:<code><<=</code></li><li><code>__irshift__(self, other)</code> 定义赋值按位右移位的行为:<code>>>=</code></li><li><code>__iand__(self, other)</code> 定义赋值按位与操作的行为:<code>&=</code></li><li><code>__ixor__(self, other)</code> 定义赋值按位异或操作的行为:<code>^=</code></li><li><code>__ior__(self, other)</code> 定义赋值按位或操作的行为:<code>|=</code><h2 id="一元运算符"><a href="#一元运算符" class="headerlink" title="一元运算符"></a>一元运算符</h2></li><li><code>__neg__(self)</code> 定义负号的行为:<code>-x</code></li><li><code>__pos__(self)</code> 定义正号的行为:<code>+x</code></li><li><code>__abs__(self)</code> 定义当被 <code>abs()</code> 调用时的行为
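<p>下面用一个简单的示意演示这几个一元运算符的重载(类名 <code>Num</code> 仅为演示而假设):</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">class Num:</span><br><span class="line">    def __init__(self, value):</span><br><span class="line">        self.value = value</span><br><span class="line">    def __neg__(self):           # 对应 -x</span><br><span class="line">        return Num(-self.value)</span><br><span class="line">    def __pos__(self):           # 对应 +x</span><br><span class="line">        return Num(+self.value)</span><br><span class="line">    def __abs__(self):           # 对应 abs(x)</span><br><span class="line">        return Num(abs(self.value))</span><br><span class="line"></span><br><span class="line">n = Num(-3)</span><br><span class="line">print((-n).value)    # 3</span><br><span class="line">print((+n).value)    # -3</span><br><span class="line">print(abs(n).value)  # 3</span><br></pre></td></tr></table></figure><h2 id="属性访问"><a href="#属性访问" class="headerlink" title="属性访问"></a>属性访问</h2></li><li><code>__getattr__(self, name)</code>: 定义当用户试图获取一个不存在的属性时的行为。</li><li><code>__getattribute__(self, name)</code>:定义当该类的属性被访问时的行为(先调用该方法,查看是否存在该属性,若不存在,接着去调用 <code>__getattr__</code>)。</li><li><code>__setattr__(self, name, value)</code>:定义当一个属性被设置时的行为。</li><li><code>__delattr__(self, name)</code>:定义当一个属性被删除时的行为。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span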
class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">C</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__getattribute__</span><span class="params">(self, item)</span>:</span></span><br><span class="line"> print(<span class="string">'__getattribute__'</span>)</span><br><span class="line"> <span class="keyword">return</span> super().__getattribute__(item)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__getattr__</span><span class="params">(self, item)</span>:</span></span><br><span class="line"> print(<span class="string">'__getattr__'</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__setattr__</span><span class="params">(self, key, value)</span>:</span></span><br><span class="line"> print(<span class="string">'__setattr__'</span>)</span><br><span class="line"> super().__setattr__(key, value)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__delattr__</span><span class="params">(self, item)</span>:</span></span><br><span class="line"> print(<span class="string">'__delattr__'</span>)</span><br><span class="line"> super().__delattr__(item)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">c = C()</span><br><span class="line">c.x</span><br><span class="line"><span class="comment"># __getattribute__</span></span><br><span class="line"><span class="comment"># __getattr__</span></span><br><span class="line"></span><br><span class="line">c.x = <span class="number">1</span></span><br><span class="line"><span class="comment"># __setattr__</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">del</span> c.x</span><br><span class="line"><span class="comment"># __delattr__</span></span><br></pre></td></tr></table></figure><h2 id="定制序列"><a href="#定制序列" class="headerlink" title="定制序列"></a>定制序列</h2></li><li>如果说你希望定制的容器是不可变的话,你只需要定义 <code>__len__()</code> 和 <code>__getitem__()</code> 方法。</li><li>如果你希望定制的容器是可变的话,除了 <code>__len__()</code> 和 <code>__getitem__()</code> 方法,你还需要定义 <code>__setitem__()</code> 和 <code>__delitem__()</code> 两个方法。</li></ul><p><strong>编写一个不可改变的自定义列表,要求记录列表中每个元素被访问的次数。</strong></p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span 
class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">CountList</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, *args)</span>:</span></span><br><span class="line"> self.values = [x <span class="keyword">for</span> x <span class="keyword">in</span> args]</span><br><span class="line"> self.count = {}.fromkeys(range(len(self.values)), <span class="number">0</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__len__</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> len(self.values)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__getitem__</span><span class="params">(self, item)</span>:</span></span><br><span class="line"> self.count[item] += <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span> self.values[item]</span><br><span class="line"></span><br><span class="line">c1 = CountList(<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>)</span><br><span class="line">c2 = CountList(<span class="number">2</span>, <span class="number">4</span>, <span class="number">6</span>, <span class="number">8</span>, <span class="number">10</span>)</span><br><span class="line">print(c1[<span class="number">1</span>]) <span class="comment"># 3</span></span><br><span class="line">print(c2[<span class="number">2</span>]) <span class="comment"># 6</span></span><br><span class="line">print(c1[<span class="number">1</span>] + c2[<span class="number">1</span>]) <span class="comment"># 7</span></span><br><span class="line"></span><br><span class="line">print(c1.count)</span><br><span class="line"><span class="comment"># {0: 0, 1: 2, 2: 0, 3: 0, 4: 0}</span></span><br><span class="line"></span><br><span class="line">print(c2.count)</span><br><span class="line"><span class="comment"># {0: 0, 1: 1, 2: 1, 3: 0, 4: 0}</span></span><br></pre></td></tr></table></figure><ul><li><code>__len__(self)</code> 定义当被 <code>len()</code> 调用时的行为(返回容器中元素的个数)。</li><li><code>__getitem__(self, key)</code> 定义获取容器中元素的行为,相当于 <code>self[key]</code>。</li><li><code>__setitem__(self, key, value)</code> 定义设置容器中指定元素的行为,相当于<code>self[key] = value</code>。</li><li><code>__delitem__(self, key)</code> 定义删除容器中指定元素的行为,相当于 <code>del self[key]</code>。</li></ul><p><strong>编写一个可改变的自定义列表,要求记录列表中每个元素被访问的次数。</strong></p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span 
class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">CountList</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, *args)</span>:</span></span><br><span class="line"> self.values = [x <span class="keyword">for</span> x <span class="keyword">in</span> args]</span><br><span class="line"> self.count = {}.fromkeys(range(len(self.values)), <span class="number">0</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__len__</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> len(self.values)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__getitem__</span><span class="params">(self, item)</span>:</span></span><br><span class="line"> self.count[item] += <span class="number">1</span></span><br><span class="line"> <span class="keyword">return</span> self.values[item]</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__setitem__</span><span class="params">(self, key, value)</span>:</span></span><br><span class="line"> self.values[key] = value</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__delitem__</span><span class="params">(self, key)</span>:</span></span><br><span class="line"> <span class="keyword">del</span> self.values[key]</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">0</span>, len(self.values)):</span><br><span class="line"> <span class="keyword">if</span> i >= key:</span><br><span class="line"> self.count[i] = self.count[i + <span class="number">1</span>]</span><br><span class="line"> self.count.pop(len(self.values))</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">c1 = CountList(<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>)</span><br><span class="line">c2 = CountList(<span class="number">2</span>, <span class="number">4</span>, <span class="number">6</span>, <span class="number">8</span>, <span class="number">10</span>)</span><br><span class="line">print(c1[<span class="number">1</span>]) <span 
class="comment"># 3</span></span><br><span class="line">print(c2[<span class="number">2</span>]) <span class="comment"># 6</span></span><br><span class="line">c2[<span class="number">2</span>] = <span class="number">12</span></span><br><span class="line">print(c1[<span class="number">1</span>] + c2[<span class="number">2</span>]) <span class="comment"># 15</span></span><br><span class="line">print(c1.count)</span><br><span class="line"><span class="comment"># {0: 0, 1: 2, 2: 0, 3: 0, 4: 0}</span></span><br><span class="line">print(c2.count)</span><br><span class="line"><span class="comment"># {0: 0, 1: 0, 2: 2, 3: 0, 4: 0}</span></span><br><span class="line"><span class="keyword">del</span> c1[<span class="number">1</span>]</span><br><span class="line">print(c1.count)</span><br><span class="line"><span class="comment"># {0: 0, 1: 0, 2: 0, 3: 0}</span></span><br></pre></td></tr></table></figure><h1 id="迭代器"><a href="#迭代器" class="headerlink" title="迭代器"></a>迭代器</h1><ul><li>迭代是 Python 最强大的功能之一,是访问集合元素的一种方式。</li><li>迭代器是一个可以记住遍历的位置的对象。</li><li>迭代器对象从集合的第一个元素开始访问,直到所有的元素被访问完结束。</li><li>迭代器只能往前不会后退。</li><li>字符串,列表或元组对象都可用于创建迭代器:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">links = {<span class="string">'B'</span>: <span class="string">'百度'</span>, <span class="string">'A'</span>: <span class="string">'阿里'</span>, <span class="string">'T'</span>: <span class="string">'腾讯'</span>}</span><br><span class="line"><span class="keyword">for</span> each <span class="keyword">in</span> iter(links):</span><br><span class="line"> print(<span class="string">'%s -> %s'</span> % (each, links[each]))</span><br><span class="line"><span class="comment"># B -> 百度</span></span><br><span class="line"><span class="comment"># A -> 阿里</span></span><br><span class="line"><span class="comment"># T -> 腾讯</span></span><br><span class="line"><span class="comment"># B -> 百度</span></span><br><span class="line"><span class="comment"># A -> 阿里</span></span><br><span class="line"><span class="comment"># T -> 腾讯</span></span><br></pre></td></tr></table></figure></li><li>迭代器有两个基本的方法:<code>iter()</code> 和 <code>next()</code>。</li><li><code>iter(object)</code> 函数用来生成迭代器。</li><li><code>next(iterator[, default])</code> 返回迭代器的下一个项目。</li><li><code>iterator</code> – 可迭代对象</li><li><code>default</code> – 可选,用于设置在没有下一个元素时返回该默认值,如果不设置,又没有下一个元素则会触发 <code>StopIteration</code> 异常。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line">links = {<span class="string">'B'</span>: <span class="string">'百度'</span>, <span class="string">'A'</span>: <span 
class="string">'阿里'</span>, <span class="string">'T'</span>: <span class="string">'腾讯'</span>}</span><br><span class="line"></span><br><span class="line">it = iter(links)</span><br><span class="line"><span class="keyword">while</span> <span class="literal">True</span>:</span><br><span class="line"> <span class="keyword">try</span>:</span><br><span class="line"> each = next(it)</span><br><span class="line"> <span class="keyword">except</span> StopIteration:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> print(each)</span><br><span class="line"></span><br><span class="line"><span class="comment"># B</span></span><br><span class="line"><span class="comment"># A</span></span><br><span class="line"><span class="comment"># T</span></span><br><span class="line"></span><br><span class="line">it = iter(links)</span><br><span class="line">print(next(it)) <span class="comment"># B</span></span><br><span class="line">print(next(it)) <span class="comment"># A</span></span><br><span class="line">print(next(it)) <span class="comment"># T</span></span><br><span class="line">print(next(it)) <span class="comment"># StopIteration</span></span><br></pre></td></tr></table></figure>把一个类作为一个迭代器使用需要在类中实现两个魔法方法 <code>__iter__()</code> 与 <code>__next__()</code>。</li><li><code>__iter__(self)</code> 定义当迭代容器中的元素的行为,返回一个特殊的迭代器对象, 这个迭代器对象实现了 <code>__next__()</code> 方法并通过 <code>StopIteration</code> 异常标识迭代的完成。</li><li><code>__next__()</code> 返回下一个迭代器对象。</li><li><code>StopIteration</code> 异常用于标识迭代的完成,防止出现无限循环的情况,在 <code>__next__()</code> 方法中我们可以设置在完成指定循环次数后触发 <code>StopIteration</code> 异常来结束迭代。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Fibs</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, n=<span class="number">10</span>)</span>:</span></span><br><span class="line"> self.a = <span class="number">0</span></span><br><span class="line"> self.b = <span class="number">1</span></span><br><span class="line"> self.n = n</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__iter__</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">return</span> self</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__next__</span><span class="params">(self)</span>:</span></span><br><span class="line"> self.a, self.b = self.b, self.a + self.b</span><br><span class="line"> <span class="keyword">if</span> self.a > self.n:</span><br><span class="line"> <span class="keyword">raise</span> StopIteration</span><br><span class="line"> <span 
class="keyword">return</span> self.a</span><br><span class="line"></span><br><span class="line">fibs = Fibs(<span class="number">100</span>)</span><br><span class="line"><span class="keyword">for</span> each <span class="keyword">in</span> fibs:</span><br><span class="line"> print(each, end=<span class="string">' '</span>)</span><br><span class="line"><span class="comment"># 1 1 2 3 5 8 13 21 34 55 89</span></span><br></pre></td></tr></table></figure><h1 id="生成器"><a href="#生成器" class="headerlink" title="生成器"></a>生成器</h1></li><li>在 Python 中,使用了 <code>yield</code> 的函数被称为生成器(generator)。</li><li>跟普通函数不同的是,生成器是一个返回迭代器的函数,只能用于迭代操作,更简单点理解生成器就是一个迭代器。</li><li>在调用生成器运行的过程中,每次遇到 <code>yield</code> 时函数会暂停并保存当前所有的运行信息,返回 <code>yield</code> 的值, 并在下一次执行 <code>next()</code> 方法时从当前位置继续运行。</li><li>调用一个生成器函数,返回的是一个迭代器对象。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">myGen</span><span class="params">()</span>:</span></span><br><span class="line"> print(<span class="string">'生成器执行!'</span>)</span><br><span class="line"> <span class="keyword">yield</span> <span class="number">1</span></span><br><span class="line"> <span class="keyword">yield</span> <span class="number">2</span></span><br><span class="line"> </span><br><span class="line">myG = myGen()</span><br><span class="line"><span class="keyword">for</span> each <span class="keyword">in</span> myG:</span><br><span class="line"> print(each)</span><br><span class="line"></span><br><span class="line"><span class="string">'''</span></span><br><span class="line"><span class="string">生成器执行!</span></span><br><span class="line"><span class="string">1</span></span><br><span class="line"><span class="string">2</span></span><br><span class="line"><span class="string">'''</span></span><br><span class="line"></span><br><span class="line">myG = myGen()</span><br><span class="line">print(next(myG)) </span><br><span class="line"><span class="comment"># 生成器执行!</span></span><br><span class="line"><span class="comment"># 1</span></span><br><span class="line"></span><br><span class="line">print(next(myG)) <span class="comment"># 2</span></span><br><span class="line">print(next(myG)) <span class="comment"># StopIteration</span></span><br></pre></td></tr></table></figure></li></ul>]]></content>
<tags>
<tag> python </tag>
</tags>
</entry>
<entry>
<title>[DSU&阿里云天池] Python训练营 Task 2</title>
<link href="2020/12/21/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task-2/"/>
<url>2020/12/21/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task-2/</url>
<content type="html"><![CDATA[<p>[TOC]</p><h1 id="列表"><a href="#列表" class="headerlink" title="列表"></a>列表</h1><h2 id="定义"><a href="#定义" class="headerlink" title="定义"></a>定义</h2><p>列表是有序不定长集合语法为 <code>[元素 1, 元素 2, 元素 3,..., 元素 N]</code></p><h2 id="列表的创建"><a href="#列表的创建" class="headerlink" title="列表的创建"></a>列表的创建</h2><h3 id="创建一个普通列表"><a href="#创建一个普通列表" class="headerlink" title="创建一个普通列表"></a>创建一个普通列表</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">a = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>] <span class="comment"># [1, 2, 3]</span></span><br><span class="line">b = list(<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>) <span class="comment"># [4, 5, 6]</span></span><br></pre></td></tr></table></figure><h3 id="使用列表解析式"><a href="#使用列表解析式" class="headerlink" title="使用列表解析式"></a>使用列表解析式</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">[i <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">10</span>)] <span class="comment"># [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]</span></span><br></pre></td></tr></table></figure><p><strong>列表内的元素可以是不同类型。</strong></p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">mix = [<span class="number">1</span>, <span class="string">'lsgo'</span>, <span class="number">3.14</span>, [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]] <span class="comment"># mix = [1, 'lsgo', 3.14, [1, 2, 3]]</span></span><br></pre></td></tr></table></figure><a id="more"></a><h3 id="创建一个空列表。"><a href="#创建一个空列表。" class="headerlink" title="创建一个空列表。"></a>创建一个空列表。</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">lst = [] <span class="comment"># []</span></span><br><span class="line">lst2 = list() <span class="comment"># []</span></span><br></pre></td></tr></table></figure><h3 id="列表的方法"><a href="#列表的方法" class="headerlink" title="列表的方法"></a>列表的方法</h3><ul><li>向列表结尾添加一个元素:<code>lst.append(obj)</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">lst.append(<span class="number">4</span>) <span class="comment"># [1, 2, 3, 4]</span></span><br></pre></td></tr></table></figure></li><li>向列表结尾添加多个元素:<code>lst.extend(obj)</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">lst.extend([<span class="number">4</span>, <span class="number">5</span>]) <span class="comment"># [1, 2, 3, 4, 5]</span></span><br></pre></td></tr></table></figure>注意 <code>append</code> 和 <code>extend</code> 的区别:<figure class="highlight py"><table><tr><td class="gutter"><pre><span 
class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">lst.append([<span class="number">4</span>, <span class="number">5</span>]) <span class="comment"># [1, 2, 3, [4, 5]]</span></span><br></pre></td></tr></table></figure></li><li>删除列表中的指定元素:<code>lst.remove(obj)</code><br><code>lst.remove()</code> 删除列表中的第一个匹配项。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">1</span>]</span><br><span class="line">lst.remove(<span class="number">1</span>) <span class="comment"># [2, 3, 1]</span></span><br></pre></td></tr></table></figure></li><li>删除列表中指定位置的元素:<code>lst.pop([index=-1])</code><br>通常省略<code>index=-1</code>,即移除最后一个值,并返回它。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">y = lst.pop()</span><br><span class="line">print(lst, y)</span><br><span class="line"><span class="comment"># [1, 2] 3</span></span><br></pre></td></tr></table></figure></li><li>在指定位置 <code>index</code> 插入元素:<code>lst.insert(index, obj)</code><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">lst.insert(<span class="number">0</span>, <span class="number">4</span>)</span><br><span class="line">print(lst)</span><br><span class="line"><span class="comment"># [4, 1, 2, 3]</span></span><br></pre></td></tr></table></figure><h3 id="获取列表中的元素"><a href="#获取列表中的元素" class="headerlink" title="获取列表中的元素"></a>获取列表中的元素</h3></li><li>通过元素的索引值,从列表获取单个元素,注意,列表索引值是从0开始的。</li><li>通过将索引指定为-1,可让Python返回最后一个列表元素,索引 -2 返回倒数第二个列表元素,以此类推。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>]</span><br><span class="line">print(lst[<span class="number">0</span>])</span><br><span class="line"><span class="comment"># [1]</span></span><br></pre></td></tr></table></figure>切片的通用写法是 start : stop : step</li><li>情况 1 - “start :”</li></ul><p>以 step 为 1 (默认) 从编号 start 往列表尾部切片。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>]</span><br><span class="line">print(lst[<span 
class="number">2</span>:])</span><br><span class="line"><span class="comment"># [3, 4, 5]</span></span><br></pre></td></tr></table></figure><ul><li>情况 2 - “: stop”</li></ul><p>以 step 为 1 (默认) 从列表头部往编号 stop 切片。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>]</span><br><span class="line">print(lst[:<span class="number">2</span>])</span><br><span class="line"><span class="comment"># [1, 2]</span></span><br></pre></td></tr></table></figure><ul><li>情况 3 - “start : stop”</li></ul><p>以 step 为 1 (默认) 从编号 start 往编号 stop 切片。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>]</span><br><span class="line">print(lst[<span class="number">1</span>:<span class="number">3</span>])</span><br><span class="line"><span class="comment"># [2, 3]</span></span><br></pre></td></tr></table></figure><ul><li>情况 4 - “start : stop : step”</li></ul><p>以具体的 step 从编号 start 往编号 stop 切片。注意最后把 step 设为 -1,相当于将列表反向排列。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">lst = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>]</span><br><span class="line">print(lst[<span class="number">1</span>:<span class="number">5</span>:<span class="number">2</span>])</span><br><span class="line"><span class="comment"># [2, 4]</span></span><br></pre></td></tr></table></figure><ul><li>情况 5 - “ : “<br>复制列表中的所有元素(浅拷贝)。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">lst1 = [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]</span><br><span class="line">lst2 = lst1[:]</span><br><span class="line">lst2[<span class="number">0</span>] = <span class="number">4</span></span><br><span class="line">print(lst1, lst2)</span><br><span class="line"><span class="comment"># [1, 2, 3] [4, 2, 3]</span></span><br></pre></td></tr></table></figure><h3 id="列表的常用操作符"><a href="#列表的常用操作符" class="headerlink" title="列表的常用操作符"></a>列表的常用操作符</h3></li><li>等号操作符:<code>==</code></li><li>连接操作符:<code>+</code></li><li>重复操作符:<code>*</code></li><li>成员关系操作符:<code>in</code>、<code>not in</code></li></ul><p>等号 <code>==</code>,只有成员、成员位置都相同时才返回 <code>True</code>。</p><p>列表拼接有两种方式,用加号 <code>+</code>和乘号 <code>*</code>,前者首尾拼接,后者复制拼接。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">list1 = [<span class="number">123</span>, <span class="number">456</span>]</span><br><span class="line">list2 = [<span class="number">456</span>, <span class="number">123</span>]</span><br><span class="line">list3 = [<span class="number">123</span>, <span class="number">456</span>]</span><br><span class="line"></span><br><span class="line">print(list1 == list2) <span class="comment"># False</span></span><br><span class="line">print(list1 == list3) <span class="comment"># True</span></span><br><span class="line"></span><br><span class="line">list4 = list1 + list2 <span class="comment"># extend()</span></span><br><span class="line">print(list4) <span class="comment"># [123, 456, 456, 123]</span></span><br><span class="line"></span><br><span class="line">list5 = list3 * <span class="number">3</span></span><br><span class="line">print(list5) <span class="comment"># [123, 456, 123, 456, 123, 456]</span></span><br><span class="line"></span><br><span class="line">list3 *= <span class="number">3</span></span><br><span class="line">print(list3) <span class="comment"># [123, 456, 123, 456, 123, 456]</span></span><br><span class="line"></span><br><span class="line">print(<span class="number">123</span> <span class="keyword">in</span> list3) <span class="comment"># True</span></span><br><span class="line">print(<span class="number">456</span> <span class="keyword">not</span> <span class="keyword">in</span> list3) <span class="comment"># False</span></span><br></pre></td></tr></table></figure><h3 id="列表的其它方法"><a href="#列表的其它方法" class="headerlink" title="列表的其它方法"></a>列表的其它方法</h3><ul><li><code>lst.count(obj)</code>:统计某个元素在列表中出现的次数<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">list1 = [<span class="number">123</span>, <span class="number">456</span>] * <span class="number">3</span> <span class="comment"># [123, 456, 123, 456, 123, 456]</span></span><br><span class="line">num = list1.count(<span class="number">123</span>)</span><br><span class="line">print(num) <span class="comment"># 3</span></span><br></pre></td></tr></table></figure></li><li><code>lst.reverse()</code>:从列表中找出某个值第一个匹配项的索引位置<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">list1 = [<span class="number">123</span>, <span class="number">456</span>] * <span class="number">5</span></span><br><span class="line">print(list1.index(<span class="number">123</span>)) <span class="comment"># 0</span></span><br><span class="line">print(list1.index(<span class="number">123</span>, <span class="number">1</span>)) <span class="comment"># 2</span></span><br><span class="line">print(list1.index(<span class="number">123</span>, <span class="number">3</span>, <span class="number">7</span>)) <span class="comment"># 4</span></span><br></pre></td></tr></table></figure></li><li><code>lst.index(obj)</code>:反向列表中元素<figure 
class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">x = [<span class="number">123</span>, <span class="number">456</span>, <span class="number">789</span>]</span><br><span class="line">x.reverse() <span class="comment"># [789, 456, 123]</span></span><br></pre></td></tr></table></figure></li><li><code>lst.sort(key=None, rverse=False)</code>:对原列表进行排序。<ul><li>key – 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。</li><li>reverse – 排序规则,reverse = True 降序, reverse = False 升序(默认)。</li><li>该方法没有返回值,但是会对列表的对象进行排序。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">x = [<span class="number">123</span>, <span class="number">456</span>, <span class="number">789</span>, <span class="number">213</span>]</span><br><span class="line">x.sort() <span class="comment"># [123, 213, 456, 789]</span></span><br><span class="line"></span><br><span class="line">x.sort(reverse=<span class="literal">True</span>) <span class="comment"># [789, 456, 213, 123]</span></span><br><span class="line"></span><br><span class="line">x.sort(key=<span class="keyword">lambda</span> a: a[<span class="number">0</span>]) <span class="comment"># [(1, 3), (2, 2), (3, 4), (4, 1)]</span></span><br></pre></td></tr></table></figure><h1 id="元组"><a href="#元组" class="headerlink" title="元组"></a>元组</h1>元组的语法:<code>(元素1, 元素2,元素3,...,元素N)</code><h2 id="创建和访问一个元组"><a href="#创建和访问一个元组" class="headerlink" title="创建和访问一个元组"></a>创建和访问一个元组</h2></li></ul></li><li>Python 的元组与列表类似,不同之处在于tuple被创建后就不能对其进行修改,类似字符串。</li><li>元组使用小括号,列表使用方括号。</li><li>元组与列表类似,也用整数来对它进行索引 (indexing) 和切片 (slicing)。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">t1 = (<span class="number">1</span>, <span class="number">2</span>, <span class="string">'a'</span>) <span class="comment"># (1, 2, 'a')</span></span><br><span class="line">t2 = (<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>, <span class="number">7</span>)</span><br><span class="line">t2[<span class="number">1</span>] <span class="comment"># 2</span></span><br><span class="line">t2[<span class="number">4</span>:] <span class="comment"># (5, 6, 7)</span></span><br><span class="line">t2[:<span class="number">3</span>] <span class="comment"># (1, 2, 3)</span></span><br></pre></td></tr></table></figure></li><li>创建元组可以用小括号 (),也可以什么都不用,为了可读性,建议还是用 ()。</li><li>元组中只包含一个元素时,需要在元素后面添加逗号,否则括号会被当作运算符使用。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">x = (<span class="number">1</span>)</span><br><span class="line">type(x) <span class="comment"># int</span></span><br><span 
class="line"></span><br><span class="line">x = ()</span><br><span class="line">type(x) <span class="comment"># tuple</span></span><br><span class="line"></span><br><span class="line">x = (<span class="number">1</span>, )</span><br><span class="line">type(x) <span class="comment"># tuple</span></span><br></pre></td></tr></table></figure><h2 id="更新和删除一个元组"><a href="#更新和删除一个元组" class="headerlink" title="更新和删除一个元组"></a>更新和删除一个元组</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">num = (<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>)</span><br><span class="line">new = num[:<span class="number">2</span>] + (<span class="number">3</span>, ) + num[<span class="number">4</span>:] <span class="comment"># (1, 2, 3, 5)</span></span><br></pre></td></tr></table></figure>元组有不可更改 (immutable) 的性质,因此不能直接给元组的元素赋值,但是只要元组中的元素可更改 (mutable),那么我们可以直接更改其元素,注意这跟赋值其元素不同。<figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">t = (1, 2, 3, [4, 5, 6])</span><br><span class="line">t[3][0] = 9</span><br><span class="line">t # (1, 2, 3, [9, 5, 6])</span><br></pre></td></tr></table></figure><h2 id="元组相关的操作符"><a href="#元组相关的操作符" class="headerlink" title="元组相关的操作符"></a>元组相关的操作符</h2></li><li>等号操作符:<code>==</code></li><li>连接操作符:<code>+</code></li><li>重复操作符:<code>*</code></li><li>成员关系操作符:<code>in</code>、<code>not in</code></li></ul><p>等号 <code>==</code>,只有成员、成员位置都相同时才返回 <code>True</code>。</p><p>元组拼接有两种方式,用加号 <code>+</code> 和乘号 <code>*</code>,前者首尾拼接,后者复制拼接。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">t1 = (<span class="number">123</span>, <span class="number">456</span>)</span><br><span class="line">t2 = (<span class="number">456</span>, <span class="number">123</span>)</span><br><span class="line">t3 = (<span class="number">123</span>, <span class="number">456</span>)</span><br><span class="line"></span><br><span class="line">print(t1 == t2) <span class="comment"># False</span></span><br><span class="line">print(t1 == t3) <span class="comment"># True</span></span><br><span class="line"></span><br><span class="line">t4 = t1 + t2</span><br><span class="line">print(t4) <span class="comment"># (123, 456, 456, 123)</span></span><br><span class="line"></span><br><span class="line">t5 = t3 * <span class="number">3</span></span><br><span class="line">print(t5) <span class="comment"># (123, 456, 123, 456, 123, 456)</span></span><br><span class="line"></span><br><span class="line">t3 *= <span class="number">3</span></span><br><span class="line">print(t3) <span class="comment"># (123, 456, 123, 456, 123, 
<span class="line"></span><br><span class="line">print(<span class="number">123</span> <span class="keyword">in</span> t3) <span class="comment"># True</span></span><br><span class="line">print(<span class="number">456</span> <span class="keyword">not</span> <span class="keyword">in</span> t3) <span class="comment"># False</span></span><br></pre></td></tr></table></figure><h2 id="内置方法"><a href="#内置方法" class="headerlink" title="内置方法"></a>内置方法</h2><ul><li><code>count</code>:返回在元组中该元素出现几次<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">t = (<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">2</span>)</span><br><span class="line">t.count(<span class="number">2</span>) <span class="comment"># 2</span></span><br></pre></td></tr></table></figure></li><li><code>index</code>:找到指定元素在元组中的索引<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">t = (<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">2</span>)</span><br><span class="line">t.index(<span class="number">3</span>) <span class="comment"># 2</span></span><br></pre></td></tr></table></figure><h2 id="解压元组"><a href="#解压元组" class="headerlink" title="解压元组"></a>解压元组</h2>如果只想要元组中的部分元素,可以在解包时用带 <code>*</code> 的变量收集剩余元素。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">t = (<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>)</span><br><span class="line">a, b, *rest, c = t</span><br><span class="line">print(a, b, c) <span class="comment"># 1 2 5</span></span><br><span class="line">print(rest) <span class="comment"># [3, 4]</span></span><br></pre></td></tr></table></figure>
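<p>元组的打包与解包还常用来交换两个变量(补充的最小示例):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">a, b = <span class="number">1</span>, <span class="number">2</span></span><br><span class="line">a, b = b, a <span class="comment"># 右边先打包成元组 (2, 1),再解包赋给 a, b</span></span><br><span class="line">print(a, b) <span class="comment"># 2 1</span></span><br></pre></td></tr></table></figure><h1 id="字符串"><a href="#字符串" class="headerlink" title="字符串"></a>字符串</h1><h2 id="字符串的定义"><a href="#字符串的定义" class="headerlink" title="字符串的定义"></a>字符串的定义</h2></li><li>Python 中字符串被定义为引号之间的字符集合。</li><li>Python 支持使用成对的 单引号 或 双引号。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="number">1</span> + <span class="number">2</span>) <span class="comment"># 3</span></span><br><span class="line">print(<span class="string">'1'</span> + <span class="string">'2'</span>) <span class="comment"># 12</span></span><br></pre></td></tr></table></figure></li><li>常用的转义字符:<code>\\</code> 反斜杠符号,<code>\t</code> 横向制表符,<code>\n</code> 换行符<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="string">'let\'s go'</span>) <span class="comment"># let's go</span></span><br><span class="line">print(<span class="string">'C:\\temp'</span>) <span class="comment"># C:\temp</span></span><br></pre></td></tr></table></figure></li><li>原始字符串只需要在字符串前边加一个英文字母 <code>r</code> 即可。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td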
class="code"><pre><span class="line">print(<span class="string">r'C:\temp'</span>) <span class="comment"># C:\temp</span></span><br></pre></td></tr></table></figure><h2 id="字符串的切片与拼接"><a href="#字符串的切片与拼接" class="headerlink" title="字符串的切片与拼接"></a>字符串的切片与拼接</h2></li><li>类似于元组具有不可修改性;</li><li>从 0 开始;</li><li>切片通常写成 <code>start:end</code> 这种形式,包括 <code>start</code> 索引对应的元素,不包括<code>end</code> 索引对应的元素;</li><li>索引值可正可负,正索引从 0 开始,从左往右;负索引从 -1 开始,从右往左。使用负数索引时,会从最后一个元素开始计数。最后一个元素的位置编号是 -1。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">s1 = <span class="string">'Hello World'</span></span><br><span class="line">s1[:<span class="number">5</span>] <span class="comment"># Hello</span></span><br></pre></td></tr></table></figure><h2 id="字符串的常用内置方法"><a href="#字符串的常用内置方法" class="headerlink" title="字符串的常用内置方法"></a>字符串的常用内置方法</h2></li><li><code>capitalize()</code>:将字符串的第一个字符转换为大写。</li><li><code>lower()</code>:转换字符串中所有大写字符为小写。</li><li><code>upper()</code>:转换字符串中的小写字母为大写。</li><li><code>swapcase()</code>:将字符串中大写转换为小写,小写转换为大写。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'hello WORLD'</span></span><br><span class="line">s.capitalize() <span class="comment"># Hello WORLD</span></span><br><span class="line">s.lower() <span class="comment"># hello world</span></span><br><span class="line">s.upper() <span class="comment"># HELLO WORLD</span></span><br><span class="line">s.swapcase() <span class="comment"># HELLO world</span></span><br></pre></td></tr></table></figure></li><li><code>count(str, beg= 0,end=len(string))</code> 返回 <code>str</code> 在 string 里面出现的次数,如果 <code>beg</code> 或者 <code>end</code> 指定则返回指定范围内 <code>str</code> 出现的次数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'DAXIExiaoxie'</span></span><br><span class="line">s.count(<span class="string">'xi'</span>) <span class="comment"># 2</span></span><br></pre></td></tr></table></figure></li><li><code>endswith(suffix, beg=0, end=len(string))</code> 检查字符串是否以指定子字符串 <code>suffix</code> 结束,如果是,返回 <code>True</code>,否则返回 <code>False</code>。如果 <code>beg</code> 和 <code>end</code> 指定值,则在指定范围内检查。</li><li><code>startswith(substr, beg=0,end=len(string))</code> 检查字符串是否以指定子字符串 <code>substr</code> 开头,如果是,返回 <code>True</code>,否则返回 <code>False1</code>。如果 <code>beg</code> 和 <code>end</code> 指定值,则在指定范围内检查。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'hello world'</span></span><br><span class="line">s.startswith(<span class="string">'he'</span>) <span class="comment"># True</span></span><br><span class="line">s.endswith(<span class="string">'LD'</span>) <span class="comment"># False</span></span><br></pre></td></tr></table></figure></li><li><code>find(str, beg=0, end=len(string))</code> 检测 <code>str</code> 是否包含在字符串中,如果指定范围 <code>beg</code> 和 <code>end</code>,则检查是否包含在指定范围内,如果包含,返回开始的索引值,否则返回 -1。</li><li><code>rfind(str, beg=0,end=len(string))</code> 类似于 
<code>find()</code> 函数,不过是从右边开始查找。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'hello world'</span></span><br><span class="line">s.find(<span class="string">'o'</span>) <span class="comment"># 4</span></span><br><span class="line">s.rfind(<span class="string">'o'</span>) <span class="comment"># 7</span></span><br></pre></td></tr></table></figure></li><li><code>isnumeric()</code> 如果字符串中只包含数字字符,则返回 <code>True</code>,否则返回 <code>False</code>。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'1234'</span></span><br><span class="line">s.isnumeric() <span class="comment"># True</span></span><br><span class="line">s += <span class="string">'a'</span> <span class="comment"># '1234a'</span></span><br><span class="line">s.isnumeric() <span class="comment"># False</span></span><br></pre></td></tr></table></figure></li><li><code>ljust(width[, fillchar])</code> 返回一个原字符串左对齐,并使用 fillchar(默认空格)填充至长度 width 的新字符串。</li><li><code>rjust(width[, fillchar])</code> 返回一个原字符串右对齐,并使用 fillchar(默认空格)填充至长度 width 的新字符串。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'abcd'</span></span><br><span class="line">print(s.ljust(<span class="number">8</span>, <span class="string">'0'</span>)) <span class="comment"># 'abcd0000'</span></span><br><span class="line">print(s.rjust(<span class="number">8</span>, <span class="string">'0'</span>)) <span class="comment"># '0000abcd'</span></span><br></pre></td></tr></table></figure></li><li><code>lstrip([chars])</code> 截掉字符串左边的空格或指定字符。</li><li><code>rstrip([chars])</code> 删除字符串末尾的空格或指定字符。</li><li><code>strip([chars])</code> 在字符串上执行 lstrip() 和 rstrip()。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">' Hello World '</span></span><br><span class="line">s.lstrip() <span class="comment"># 'Hello World '</span></span><br><span class="line">s.rstrip() <span class="comment"># ' Hello World'</span></span><br><span class="line">s.strip() <span class="comment"># 'Hello World'</span></span><br></pre></td></tr></table></figure>
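<p>chars 参数用于指定要删除的字符(补充示例):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'##Hello World##'</span></span><br><span class="line">s.strip(<span class="string">'#'</span>) <span class="comment"># 'Hello World'</span></span><br></pre></td></tr></table></figure></li><li><code>partition(sub)</code> 找到子字符串 <code>sub</code>,把字符串分为一个三元组 <code>(pre_sub,sub,fol_sub)</code>,如果字符串中不包含 <code>sub</code> 则返回 <code>('原字符串','','')</code>。</li><li><code>rpartition(sub)</code> 类似于 <code>partition()</code> 方法,不过是从右边开始查找。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'Hello World'</span></span><br><span class="line">s.partition(<span class="string">'o'</span>) <span class="comment"># ('Hell', 'o', ' World')</span></span><br><span class="line">s.rpartition(<span class="string">'o'</span>) <span class="comment"># ('Hello W', 'o', 'rld')</span></span><br></pre></td></tr></table></figure></li><li><code>replace(old, new [,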
max])</code> 把字符串中的 <code>old</code> 替换成 <code>new</code>,如果 <code>max</code> 指定,则替换不超过 <code>max</code> 次。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'Hello World'</span></span><br><span class="line">s.replace(<span class="string">'o'</span>, <span class="string">'a'</span>, <span class="number">1</span>) <span class="comment"># Hella World</span></span><br></pre></td></tr></table></figure></li><li><code>split(str="", num)</code> 不带参数默认是以空格为分隔符切片字符串,如果 <code>num</code> 参数有设置,则最多分隔 <code>num</code> 次(得到 num+1 个子字符串),返回分隔后的子字符串列表。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'Hello World'</span></span><br><span class="line">s.split() <span class="comment"># ['Hello', 'World']</span></span><br><span class="line">s.split(<span class="string">'l'</span>, <span class="number">2</span>) <span class="comment"># ['He', '', 'o World']</span></span><br></pre></td></tr></table></figure><h2 id="字符串格式化"><a href="#字符串格式化" class="headerlink" title="字符串格式化"></a>字符串格式化</h2><p><code>f-string</code> 是格式化字符串的简便写法:在字符串前加 <code>f</code> 前缀,把要嵌入的表达式写在 <code>{}</code> 中。</p>
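<p>给一个最小的 f-string 示例(补充的演示代码):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">name = <span class="string">'Python'</span></span><br><span class="line">print(<span class="string">f'Hello, {name}!'</span>) <span class="comment"># Hello, Python!</span></span><br><span class="line">print(<span class="string">f'{3.14159:.2f}'</span>) <span class="comment"># 3.14</span></span><br></pre></td></tr></table></figure><h1 id="字典"><a href="#字典" class="headerlink" title="字典"></a>字典</h1><h2 id="可变类型与不可变类型"><a href="#可变类型与不可变类型" class="headerlink" title="可变类型与不可变类型"></a>可变类型与不可变类型</h2><ul><li>序列是以连续的整数为索引,与此不同的是,字典以"关键字"为索引,关键字可以是任意不可变类型,通常用字符串或数值。</li><li>字典是 Python 唯一的一个映射类型,字符串、元组、列表属于序列类型。</li></ul><p>那么如何快速判断一个数据类型 <code>X</code> 是不是可变类型呢?有两种方法:</p><ul><li>麻烦方法:用 <code>id(X)</code> 函数,对 <code>X</code> 进行某种操作,比较操作前后的 <code>id</code>,如果不一样,则 <code>X</code> 不可变,如果一样,则 <code>X</code> 可变。</li><li>便捷方法:用 <code>hash(X)</code>,只要不报错,证明 <code>X</code> 可被哈希,即不可变;反过来,不可被哈希,即可变。</li></ul><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">i = <span class="number">1</span></span><br><span class="line">print(id(i)) <span class="comment"># 4376351872</span></span><br><span class="line">i = i + <span class="number">2</span></span><br><span class="line">print(id(i)) <span class="comment"># 4376351936</span></span><br><span class="line"></span><br><span class="line">l = [<span class="number">1</span>, <span class="number">2</span>]</span><br><span class="line">print(id(l)) <span class="comment"># 140478358404224</span></span><br><span class="line">l.append(<span class="string">'Python'</span>)</span><br><span class="line">print(id(l)) <span class="comment"># 140478358404224</span></span><br></pre></td></tr></table></figure></li><li>整数 <code>i</code> 在加 2 之后的 <code>id</code> 和之前不一样,说明加完之后的这个 <code>i</code>(虽然名字没变)已经不是之前的那个 <code>i</code> 了,因此整数是不可变类型。</li><li>列表 <code>l</code> 在附加 <code>'Python'</code> 之后的 <code>id</code> 和之前一样,因此列表是可变类型。<figure class="highlight py"><table><tr><td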
class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">print(hash(<span class="string">'Name'</span>)) <span class="comment"># 2936281444635265970</span></span><br><span class="line">print(hash((<span class="number">1</span>, <span class="number">2</span>, <span class="string">'Python'</span>))) <span class="comment"># 5769585943857102932</span></span><br><span class="line">print(hash([<span class="number">1</span>, <span class="number">2</span>, <span class="string">'Python'</span>]))</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">TypeError Traceback (most recent call last)</span><br><span class="line"><ipython-input<span class="number">-146</span>-f9149cd3bfae> <span class="keyword">in</span> <module></span><br><span class="line"> <span class="number">3</span> print(hash((<span class="number">1</span>, <span class="number">2</span>, <span class="string">'Python'</span>))) <span class="comment"># 1704535747474881831</span></span><br><span class="line"> <span class="number">4</span> </span><br><span class="line">----> 5 print(hash([1, 2, 'Python']))</span><br><span class="line"></span><br><span class="line">TypeError: unhashable type: <span class="string">'list'</span></span><br></pre></td></tr></table></figure></li><li>数值、字符和元组 都能被哈希,因此它们是不可变类型。</li><li>列表、集合、字典不能被哈希,因此它是可变类型。<h2 id="字典的定义"><a href="#字典的定义" class="headerlink" title="字典的定义"></a>字典的定义</h2></li></ul><p>字典 是无序的 键:值(<code>key:value</code>)对集合,键必须是互不相同的(在同一个字典之内)。</p><ul><li><code>dict</code> 内部存放的顺序和 <code>key</code> 放入的顺序是没有关系的。</li><li><code>dict</code> 查找和插入的速度极快,不会随着 <code>key</code> 的增加而增加,但是需要占用大量的内存。<br>字典 定义语法为 <code>{元素1, 元素2, ..., 元素n}</code></li><li>其中每一个元素是一个「键值对」– 键:值 (<code>key:value</code>)</li><li>关键点是大括号 <code>{}</code>,逗号 <code>,</code>和冒号 <code>:</code></li><li>大括号 – 把所有元素绑在一起</li><li>逗号 – 将每个键值对分开</li><li>冒号 – 将键和值分开<h2 id="创建和访问字典"><a href="#创建和访问字典" class="headerlink" title="创建和访问字典"></a>创建和访问字典</h2>如果我们取的键在字典中不存在,会直接报错KeyError。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">dct = {<span class="string">'a'</span>:<span class="string">'hello'</span>, <span class="string">'b'</span>:<span class="string">'world'</span>} <span class="comment"># {'a':'hello', 'b':'world'}</span></span><br><span class="line">dct[<span class="string">'c'</span>]</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">KeyError Traceback (most recent call last)</span><br><span class="line"><ipython-input<span class="number">-147</span><span class="number">-556</span>b5f4b8778> <span class="keyword">in</span> <module></span><br><span class="line"> <span class="number">1</span> dct = {<span class="string">'a'</span>:<span class="string">'hello'</span>, <span 
class="string">'b'</span>:<span class="string">'world'</span>} <span class="comment"># {'a':'hello', 'b':'world'}</span></span><br><span class="line">----> 2 dct['c']</span><br><span class="line"></span><br><span class="line">KeyError: <span class="string">'c'</span></span><br></pre></td></tr></table></figure></li><li><code>dict()</code> 创建一个空的字典。</li></ul><p>通过 <code>key</code> 直接把数据放入字典中,但一个 <code>key</code> 只能对应一个 <code>value</code>,多次对一个 <code>key</code> 放入 <code>value</code>,后面的值会把前面的值冲掉。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">dic = dict()</span><br><span class="line">dic[<span class="string">'a'</span>] = <span class="number">1</span></span><br><span class="line">dic[<span class="string">'b'</span>] = <span class="number">2</span></span><br><span class="line">dic[<span class="string">'c'</span>] = <span class="number">3</span></span><br><span class="line"></span><br><span class="line">print(dic) <span class="comment"># {'a': 1, 'b': 2, 'c': 3}</span></span><br><span class="line"></span><br><span class="line">dic[<span class="string">'a'</span>] = <span class="number">11</span></span><br><span class="line">print(dic) <span class="comment"># {'a': 11, 'b': 2, 'c': 3}</span></span><br><span class="line"></span><br><span class="line">dic[<span class="string">'d'</span>] = <span class="number">4</span></span><br><span class="line">print(dic) <span class="comment"># {'a': 11, 'b': 2, 'c': 3, 'd': 4}</span></span><br></pre></td></tr></table></figure><ul><li><code>dict(mapping)</code> 通过键值对创建字典<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">dic1 = dict([(<span class="string">'apple'</span>, <span class="number">4139</span>), (<span class="string">'peach'</span>, <span class="number">4127</span>), (<span class="string">'cherry'</span>, <span class="number">4098</span>)])</span><br><span class="line">print(dic1) <span class="comment"># {'cherry': 4098, 'apple': 4139, 'peach': 4127}</span></span><br></pre></td></tr></table></figure></li><li><code>dict(**kwargs)</code> 通过 <code>name=value</code> 方式创建字典<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">dic = dict(name=<span class="string">'Tom'</span>, age=<span class="number">10</span>)</span><br><span class="line">print(dic) <span class="comment"># {'name': 'Tom', 'age': 10}</span></span><br></pre></td></tr></table></figure><h2 id="字典的内置方法"><a href="#字典的内置方法" class="headerlink" title="字典的内置方法"></a>字典的内置方法</h2></li><li><code>dict.fromkeys(seq[, value])</code> 用于创建一个新字典,以序列 <code>seq</code> 中元素做字典的键,<code>value</code> 为字典所有键对应的初始值。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span 
class="line">8</span><br></pre></td><td class="code"><pre><span class="line">seq = (<span class="string">'name'</span>, <span class="string">'age'</span>, <span class="string">'sex'</span>)</span><br><span class="line">dic1 = dict.fromkeys(seq)</span><br><span class="line">print(dic1)</span><br><span class="line"><span class="comment"># {'name': None, 'age': None, 'sex': None}</span></span><br><span class="line"></span><br><span class="line">dic2 = dict.fromkeys(seq, <span class="number">10</span>)</span><br><span class="line">print(dic2)</span><br><span class="line"><span class="comment"># {'name': 10, 'age': 10, 'sex': 10}</span></span><br></pre></td></tr></table></figure></li><li><code>dict.keys()</code> 返回一个可迭代对象,可以使用 <code>list()</code> 来转换为列表,列表为字典中的所有键。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">dic = {<span class="string">'Name'</span>: <span class="string">'lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">7</span>}</span><br><span class="line">print(dic.keys()) <span class="comment"># dict_keys(['Name', 'Age'])</span></span><br><span class="line">lst = list(dic.keys()) <span class="comment"># 转换为列表</span></span><br><span class="line">print(lst) <span class="comment"># ['Name', 'Age']</span></span><br></pre></td></tr></table></figure></li><li><code>dict.values()</code> 返回一个迭代器,可以使用 <code>list()</code> 来转换为列表,列表为字典中的所有值。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">dic = {<span class="string">'Sex'</span>: <span class="string">'female'</span>, <span class="string">'Age'</span>: <span class="number">7</span>, <span class="string">'Name'</span>: <span class="string">'Zara'</span>}</span><br><span class="line">print(dic.values()) <span class="comment"># dict_values(['female', 7, 'Zara'])</span></span><br><span class="line">print(list(dic.values()))<span class="comment"># [7, 'female', 'Zara']</span></span><br></pre></td></tr></table></figure></li><li><code>dict.items()</code> 以列表返回可遍历的 (键, 值) 元组数组。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">dic = {<span class="string">'Name'</span>: <span class="string">'Lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">7</span>}</span><br><span class="line">print(dic.items()) <span class="comment"># dict_items([('Name', 'Lsgogroup'), ('Age', 7)])</span></span><br><span class="line"></span><br><span class="line">print(tuple(dic.items())) <span class="comment"># (('Name', 'Lsgogroup'), ('Age', 7))</span></span><br><span class="line"></span><br><span class="line">print(list(dic.items())) <span class="comment"># [('Name', 'Lsgogroup'), ('Age', 7)]</span></span><br></pre></td></tr></table></figure></li><li><code>dict.get(key, default=None)</code> 返回指定键的值,如果值不在字典中返回默认值。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span 
class="line">dic = {<span class="string">'Name'</span>: <span class="string">'Lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">27</span>}</span><br><span class="line">print(<span class="string">"Age 值为 : %s"</span> % dic.get(<span class="string">'Age'</span>)) <span class="comment"># Age 值为 : 27</span></span><br><span class="line">print(<span class="string">"Sex 值为 : %s"</span> % dic.get(<span class="string">'Sex'</span>, <span class="string">"NA"</span>)) <span class="comment"># Sex 值为 : NA</span></span><br><span class="line">print(dic) <span class="comment"># {'Name': 'Lsgogroup', 'Age': 27}</span></span><br></pre></td></tr></table></figure></li><li><code>dict.setdefault(key, default=None)</code> 和 <code>get()</code> 方法 类似, 如果键不存在于字典中,将会添加键并将值设为默认值。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">dic = {<span class="string">'Name'</span>: <span class="string">'Lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">7</span>}</span><br><span class="line">print(<span class="string">"Age 键的值为 : %s"</span> % dic.setdefault(<span class="string">'Age'</span>, <span class="literal">None</span>)) <span class="comment"># Age 键的值为 : 7</span></span><br><span class="line">print(<span class="string">"Sex 键的值为 : %s"</span> % dic.setdefault(<span class="string">'Sex'</span>, <span class="literal">None</span>)) <span class="comment"># Sex 键的值为 : None</span></span><br><span class="line">print(dic) </span><br><span class="line"><span class="comment"># {'Age': 7, 'Name': 'Lsgogroup', 'Sex': None}</span></span><br></pre></td></tr></table></figure></li><li><code>dict.pop(key[,default])</code> 删除字典给定键 <code>key</code> 所对应的值,返回值为被删除的值。<code>key</code> 值必须给出。若 <code>key</code> 不存在,则返回 <code>default</code> 值。</li><li><code>del dict[key]</code> 删除字典给定键 <code>key</code> 所对应的值。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">dic1 = {<span class="number">1</span>: <span class="string">"a"</span>, <span class="number">2</span>: [<span class="number">1</span>, <span class="number">2</span>]}</span><br><span class="line">print(dic1.pop(<span class="number">1</span>), dic1) <span class="comment"># a {2: [1, 2]}</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 设置默认值,必须添加,否则报错</span></span><br><span class="line">print(dic1.pop(<span class="number">3</span>, <span class="string">"nokey"</span>), dic1) <span class="comment"># nokey {2: [1, 2]}</span></span><br></pre></td></tr></table></figure></li><li><code>dict.popitem()</code> 随机返回并删除字典中的一对键和值,如果字典已经为空,却调用了此方法,就报出 <code>KeyError</code> 异常。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">dic1 = {<span class="number">1</span>: <span class="string">"a"</span>, <span class="number">2</span>: [<span class="number">1</span>, <span class="number">2</span>]}</span><br><span class="line">print(dic1.popitem()) <span class="comment"># {2: [1, 2]}</span></span><br><span class="line">print(dic1) <span 
class="comment"># (1, 'a')</span></span><br></pre></td></tr></table></figure></li><li>dict.clear()用于删除字典内所有元素。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">dic = {<span class="string">'Name'</span>: <span class="string">'Zara'</span>, <span class="string">'Age'</span>: <span class="number">7</span>}</span><br><span class="line">print(<span class="string">"字典长度 : %d"</span> % len(dic)) <span class="comment"># 字典长度 : 2</span></span><br><span class="line">dic.clear()</span><br><span class="line">print(<span class="string">"字典删除后长度 : %d"</span> % len(dic)) <span class="comment"># 字典删除后长度 : 0</span></span><br></pre></td></tr></table></figure></li><li><code>dict.copy()</code> 返回一个字典的浅复制。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">dic1 = {<span class="string">'Name'</span>: <span class="string">'Lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">7</span>, <span class="string">'Class'</span>: <span class="string">'First'</span>}</span><br><span class="line">dic2 = dic1.copy()</span><br><span class="line">print(<span class="string">"dic2"</span>) <span class="comment"># {'Age': 7, 'Name': 'Lsgogroup', 'Class': 'First'}</span></span><br></pre></td></tr></table></figure></li><li><code>dict.update(dict2)</code> 把字典参数 <code>dict2</code> 的 <code>key:value</code> 对更新到字典 <code>dict</code> 里。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">dic = {<span class="string">'Name'</span>: <span class="string">'Lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">7</span>}</span><br><span class="line">dic2 = {<span class="string">'Sex'</span>: <span class="string">'female'</span>, <span class="string">'Age'</span>: <span class="number">8</span>}</span><br><span class="line">dic.update(dic2)</span><br><span class="line">print(dic) </span><br><span class="line">{<span class="string">'Name'</span>: <span class="string">'Lsgogroup'</span>, <span class="string">'Age'</span>: <span class="number">8</span>, <span class="string">'Sex'</span>: <span class="string">'female'</span>}</span><br></pre></td></tr></table></figure><h1 id="集合"><a href="#集合" class="headerlink" title="集合"></a>集合</h1>Python 中 <code>set</code> 与 <code>dict</code> 类似,也是一组 <code>key</code> 的集合,但不存储 <code>value</code>。由于<code>key</code> 不能重复,所以在 <code>set</code> 中,没有重复的 <code>key</code>。</li></ul><p>注意,<code>key</code> 为不可变类型,即可哈希的值。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">num = {}</span><br><span class="line">print(type(num)) <span class="comment"># <class 'dict'></span></span><br><span class="line">num = {<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>}</span><br><span class="line">print(type(num)) <span class="comment"># <class 
'set'></span></span><br></pre></td></tr></table></figure>
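<p>元素必须可哈希这一点,可以用下面的小实验验证(补充示例):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 可哈希(不可变)的元素都可以放进集合</span></span><br><span class="line">s = {<span class="number">1</span>, <span class="string">'a'</span>, (<span class="number">2</span>, <span class="number">3</span>)}</span><br><span class="line">s = {[<span class="number">1</span>, <span class="number">2</span>]}</span><br><span class="line"><span class="comment"># TypeError: unhashable type: 'list'</span></span><br></pre></td></tr></table></figure><h2 id="集合的创建"><a href="#集合的创建" class="headerlink" title="集合的创建"></a>集合的创建</h2><ul><li><p>先创建对象再加入元素。</p></li><li><p>在创建空集合的时候只能使用 <code>s = set()</code>,因为 <code>s = {}</code> 创建的是空字典。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">basket = set()</span><br><span class="line">basket.add(<span class="string">'apple'</span>)</span><br><span class="line">basket.add(<span class="string">'banana'</span>)</span><br><span class="line">print(basket) <span class="comment"># {'banana', 'apple'}</span></span><br></pre></td></tr></table></figure></li><li><p>直接把一堆元素用花括号括起来 <code>{元素1, 元素2, ..., 元素n}</code>。</p></li><li><p>重复元素在 <code>set</code> 中会被自动过滤。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">basket = {<span class="string">'apple'</span>, <span class="string">'orange'</span>, <span class="string">'apple'</span>, <span class="string">'pear'</span>, <span class="string">'orange'</span>, <span class="string">'banana'</span>}</span><br><span class="line">print(basket) <span class="comment"># {'banana', 'apple', 'pear', 'orange'}</span></span><br></pre></td></tr></table></figure></li><li><p>使用 <code>set(value)</code> 工厂函数,把列表或元组转换成集合。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">a = set(<span class="string">'abracadabra'</span>)</span><br><span class="line">print(a) </span><br><span class="line"><span class="comment"># {'r', 'b', 'd', 'c', 'a'}</span></span><br><span class="line"></span><br><span class="line">b = set((<span class="string">"Google"</span>, <span class="string">"Lsgogroup"</span>, <span class="string">"Taobao"</span>, <span class="string">"Taobao"</span>))</span><br><span class="line">print(b) </span><br><span class="line"><span class="comment"># {'Taobao', 'Lsgogroup', 'Google'}</span></span><br><span class="line"></span><br><span class="line">c = set([<span class="string">"Google"</span>, <span class="string">"Lsgogroup"</span>, <span class="string">"Taobao"</span>, <span class="string">"Google"</span>])</span><br><span class="line">print(c) </span><br><span class="line"><span class="comment"># {'Taobao', 'Lsgogroup', 'Google'}</span></span><br></pre></td></tr></table></figure><h2 id="访问集合中的值"><a href="#访问集合中的值" class="headerlink" title="访问集合中的值"></a>访问集合中的值</h2></li><li><p>可以使用 <code>len()</code> 内建函数得到集合的大小。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">s = set([<span class="string">'Google'</span>, <span class="string">'Baidu'</span>, <span class="string">'Taobao'</span>])</span><br><span class="line">print(len(s)) <span class="comment"># 3</span></span><br></pre></td></tr></table></figure></li><li><p>可以使用 <code>for</code> 把集合中的数据一个个读取出来。</p><figure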
class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">s = set([<span class="string">'Google'</span>, <span class="string">'Baidu'</span>, <span class="string">'Taobao'</span>])</span><br><span class="line"><span class="keyword">for</span> item <span class="keyword">in</span> s:</span><br><span class="line"> print(item) </span><br><span class="line"><span class="comment"># Baidu</span></span><br><span class="line"><span class="comment"># Google</span></span><br><span class="line"><span class="comment"># Taobao</span></span><br></pre></td></tr></table></figure></li><li><p>可以通过 <code>in</code> 或 <code>not in</code> 判断一个元素是否在集合中已经存在。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">s = set([<span class="string">'Google'</span>, <span class="string">'Baidu'</span>, <span class="string">'Taobao'</span>])</span><br><span class="line">print(<span class="string">'Taobao'</span> <span class="keyword">in</span> s) <span class="comment"># True</span></span><br><span class="line">print(<span class="string">'Facebook'</span> <span class="keyword">not</span> <span class="keyword">in</span> s) <span class="comment"># True</span></span><br></pre></td></tr></table></figure><h2 id="集合的内置方法"><a href="#集合的内置方法" class="headerlink" title="集合的内置方法"></a>集合的内置方法</h2></li><li><p><code>set.add(elmnt)</code> 用于给集合添加元素,如果添加的元素在集合中已存在,则不执行任何操作。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">fruits = {<span class="string">"apple"</span>, <span class="string">"banana"</span>, <span class="string">"cherry"</span>}</span><br><span class="line">fruits.add(<span class="string">"orange"</span>)</span><br><span class="line">print(fruits) </span><br><span class="line"><span class="comment"># {'orange', 'cherry', 'banana', 'apple'}</span></span><br><span class="line"></span><br><span class="line">fruits.add(<span class="string">"apple"</span>)</span><br><span class="line">print(fruits) </span><br><span class="line"><span class="comment"># {'orange', 'cherry', 'banana', 'apple'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.update(set)</code> 用于修改当前集合,可以添加新的元素或集合到当前集合中,如果添加的元素在集合中已存在,则该元素只会出现一次,重复的会忽略。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">x = {<span class="string">"apple"</span>, <span class="string">"banana"</span>, <span class="string">"cherry"</span>}</span><br><span class="line">y = {<span class="string">"google"</span>, <span class="string">"baidu"</span>, <span class="string">"apple"</span>}</span><br><span class="line">x.update(y)</span><br><span class="line">print(x) <span class="comment"># {'cherry', 'banana', 'apple', 'google', 
'baidu'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.remove(item)</code> 用于移除集合中的指定元素。如果元素不存在,则会发生错误。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">fruits = {<span class="string">"apple"</span>, <span class="string">"banana"</span>, <span class="string">"cherry"</span>}</span><br><span class="line">fruits.remove(<span class="string">"banana"</span>)</span><br><span class="line">print(fruits) <span class="comment"># {'apple', 'cherry'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.discard(value)</code> 用于移除指定的集合元素。<code>remove()</code> 方法在移除一个不存在的元素时会发生错误,而 <code>discard()</code> 方法不会。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">fruits = {<span class="string">"apple"</span>, <span class="string">"banana"</span>, <span class="string">"cherry"</span>}</span><br><span class="line">fruits.discard(<span class="string">"banana"</span>)</span><br><span class="line">print(fruits) <span class="comment"># {'apple', 'cherry'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.pop()</code> 用于随机移除一个元素。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">fruits = {<span class="string">"apple"</span>, <span class="string">"banana"</span>, <span class="string">"cherry"</span>}</span><br><span class="line">x = fruits.pop()</span><br><span class="line">print(fruits) <span class="comment"># {'cherry', 'apple'}</span></span><br><span class="line">print(x) <span class="comment"># banana</span></span><br></pre></td></tr></table></figure><p>由于 set 是无序和无重复元素的集合,所以两个或多个 set 可以做数学意义上的集合操作。</p></li><li><p><code>set.intersection(set1, set2)</code> 返回两个集合的交集。</p></li><li><p><code>set1 & set2</code> 返回两个集合的交集。</p></li><li><p><code>set.intersection_update(set1, set2)</code> 交集,在原始的集合上移除不重叠的元素。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">a = set(<span class="string">'abracadabra'</span>)</span><br><span class="line">b = set(<span class="string">'alacazam'</span>)</span><br><span class="line">print(a) <span class="comment"># {'r', 'a', 'c', 'b', 'd'}</span></span><br><span class="line">print(b) <span class="comment"># {'c', 'a', 'l', 'm', 'z'}</span></span><br><span class="line"></span><br><span class="line">c = a.intersection(b)</span><br><span class="line">print(c) <span class="comment"># {'a', 'c'}</span></span><br><span class="line">print(a & b) <span class="comment"># {'c', 'a'}</span></span><br><span class="line">print(a) <span class="comment"># {'a', 'r', 'c', 'b', 'd'}</span></span><br><span class="line"></span><br><span class="line">a.intersection_update(b)</span><br><span 
class="line">print(a) <span class="comment"># {'a', 'c'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.union(set1, set2)</code> 返回两个集合的并集。</p></li><li><p><code>set1 | set2</code> 返回两个集合的并集。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">a = set(<span class="string">'abracadabra'</span>)</span><br><span class="line">b = set(<span class="string">'alacazam'</span>)</span><br><span class="line">print(a) <span class="comment"># {'r', 'a', 'c', 'b', 'd'}</span></span><br><span class="line">print(b) <span class="comment"># {'c', 'a', 'l', 'm', 'z'}</span></span><br><span class="line"></span><br><span class="line">print(a | b) <span class="comment"># {'l', 'd', 'm', 'b', 'a', 'r', 'z', 'c'}</span></span><br><span class="line"></span><br><span class="line">c = a.union(b)</span><br><span class="line">print(c) <span class="comment"># {'c', 'a', 'd', 'm', 'r', 'b', 'z', 'l'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.difference(set)</code> 返回集合的差集。</p></li><li><p><code>set1 - set2</code> 返回集合的差集。</p></li><li><p><code>set.difference_update(set)</code> 集合的差集,直接在原来的集合中移除元素,没有返回值。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">a = set(<span class="string">'abracadabra'</span>)</span><br><span class="line">b = set(<span class="string">'alacazam'</span>)</span><br><span class="line">print(a) <span class="comment"># {'r', 'a', 'c', 'b', 'd'}</span></span><br><span class="line">print(b) <span class="comment"># {'c', 'a', 'l', 'm', 'z'}</span></span><br><span class="line"></span><br><span class="line">c = a.difference(b)</span><br><span class="line">print(c) <span class="comment"># {'b', 'd', 'r'}</span></span><br><span class="line">print(a - b) <span class="comment"># {'d', 'b', 'r'}</span></span><br><span class="line"></span><br><span class="line">print(a) <span class="comment"># {'r', 'd', 'c', 'a', 'b'}</span></span><br><span class="line">a.difference_update(b)</span><br><span class="line">print(a) <span class="comment"># {'d', 'r', 'b'}</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.issubset(set)</code> 判断集合是不是被其他集合包含,如果是则返回 <code>True</code>,否则返回 <code>False</code>。</p></li><li><p><code>set1 <= set2</code> 判断集合是不是被其他集合包含,如果是则返回 <code>True</code>,否则返回 <code>False</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">x = {<span class="string">"a"</span>, <span class="string">"b"</span>, <span class="string">"c"</span>}</span><br><span class="line">y = {<span class="string">"f"</span>, <span class="string">"e"</span>, 
<span class="string">"d"</span>, <span class="string">"c"</span>, <span class="string">"b"</span>, <span class="string">"a"</span>}</span><br><span class="line">z = x.issubset(y)</span><br><span class="line">print(z) <span class="comment"># True</span></span><br><span class="line">print(x <= y) <span class="comment"># True</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.issuperset(set)</code> 用于判断集合是不是包含其他集合,如果是则返回 <code>True</code>,否则返回 <code>False</code>。</p></li><li><p><code>set1 >= set2</code> 判断集合是不是包含其他集合,如果是则返回 <code>True</code>,否则返回 <code>False</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">x = {<span class="string">"f"</span>, <span class="string">"e"</span>, <span class="string">"d"</span>, <span class="string">"c"</span>, <span class="string">"b"</span>, <span class="string">"a"</span>}</span><br><span class="line">y = {<span class="string">"a"</span>, <span class="string">"b"</span>, <span class="string">"c"</span>}</span><br><span class="line">z = x.issuperset(y)</span><br><span class="line">print(z) <span class="comment"># True</span></span><br><span class="line">print(x >= y) <span class="comment"># True</span></span><br></pre></td></tr></table></figure></li><li><p><code>set.isdisjoint(set)</code> 用于判断两个集合是不是不相交,如果是返回 <code>True</code>,否则返回 <code>False</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">x = {<span class="string">"f"</span>, <span class="string">"e"</span>, <span class="string">"d"</span>, <span class="string">"c"</span>, <span class="string">"b"</span>}</span><br><span class="line">y = {<span class="string">"a"</span>, <span class="string">"b"</span>, <span class="string">"c"</span>}</span><br><span class="line">z = x.isdisjoint(y)</span><br><span class="line">print(z) <span class="comment"># False</span></span><br><span class="line"></span><br><span class="line">x = {<span class="string">"f"</span>, <span class="string">"e"</span>, <span class="string">"d"</span>, <span class="string">"m"</span>, <span class="string">"g"</span>}</span><br><span class="line">y = {<span class="string">"a"</span>, <span class="string">"b"</span>, <span class="string">"c"</span>}</span><br><span class="line">z = x.isdisjoint(y)</span><br><span class="line">print(z) <span class="comment"># True</span></span><br></pre></td></tr></table></figure><h2 id="不可变集合"><a href="#不可变集合" class="headerlink" title="不可变集合"></a>不可变集合</h2></li><li><p><code>frozenset([iterable])</code> 返回一个冻结的集合,冻结后集合不能再添加或删除任何元素。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">a = frozenset(range(<span class="number">10</span>)) <span class="comment"># 生成一个新的不可变集合</span></span><br><span class="line">print(a) <span class="comment"># frozenset({0, 1, 2, 3, 4, 5, 6, 7, 8, 
9})</span></span><br><span class="line"></span><br><span class="line">b = frozenset(<span class="string">'lsgogroup'</span>)</span><br><span class="line">print(b) <span class="comment"># frozenset({'g', 's', 'p', 'r', 'u', 'o', 'l'})</span></span><br></pre></td></tr></table></figure><h1 id="序列"><a href="#序列" class="headerlink" title="序列"></a>序列</h1><h2 id="针对序列的内置函数"><a href="#针对序列的内置函数" class="headerlink" title="针对序列的内置函数"></a>针对序列的内置函数</h2></li><li><p><code>list(sub)</code> 把一个可迭代对象转换为列表。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">a = list()</span><br><span class="line">print(a) <span class="comment"># []</span></span><br><span class="line"></span><br><span class="line">b = <span class="string">'I Love LsgoGroup'</span></span><br><span class="line">b = list(b)</span><br><span class="line">print(b) </span><br><span class="line"><span class="comment"># ['I', ' ', 'L', 'o', 'v', 'e', ' ', 'L', 's', 'g', 'o', 'G', 'r', 'o', 'u', 'p']</span></span><br><span class="line"></span><br><span class="line">c = (<span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">8</span>)</span><br><span class="line">c = list(c)</span><br><span class="line">print(c) <span class="comment"># [1, 1, 2, 3, 5, 8]</span></span><br></pre></td></tr></table></figure></li><li><p><code>tuple(sub)</code> 把一个可迭代对象转换为元组。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">a = tuple()</span><br><span class="line">print(a) <span class="comment"># ()</span></span><br><span class="line"></span><br><span class="line">b = <span class="string">'I Love LsgoGroup'</span></span><br><span class="line">b = tuple(b)</span><br><span class="line">print(b) </span><br><span class="line"><span class="comment"># ('I', ' ', 'L', 'o', 'v', 'e', ' ', 'L', 's', 'g', 'o', 'G', 'r', 'o', 'u', 'p')</span></span><br><span class="line"></span><br><span class="line">c = [<span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">8</span>]</span><br><span class="line">c = tuple(c)</span><br><span class="line">print(c) <span class="comment"># (1, 1, 2, 3, 5, 8)</span></span><br></pre></td></tr></table></figure></li><li><p><code>str(obj)</code> 把 <code>obj</code> 对象转换为字符串</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">a = <span class="number">123</span></span><br><span class="line">a = str(a)</span><br><span class="line">print(a) <span 
class="comment"># 123</span></span><br></pre></td></tr></table></figure></li><li><p><code>len(s)</code> 返回对象(字符、列表、元组等)长度或元素个数。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">a = list()</span><br><span class="line">print(len(a)) <span class="comment"># 0</span></span><br><span class="line"></span><br><span class="line">b = (<span class="string">'I'</span>, <span class="string">' '</span>, <span class="string">'L'</span>, <span class="string">'o'</span>, <span class="string">'v'</span>, <span class="string">'e'</span>, <span class="string">' '</span>, <span class="string">'L'</span>, <span class="string">'s'</span>, <span class="string">'g'</span>, <span class="string">'o'</span>, <span class="string">'G'</span>, <span class="string">'r'</span>, <span class="string">'o'</span>, <span class="string">'u'</span>, <span class="string">'p'</span>)</span><br><span class="line">print(len(b)) <span class="comment"># 16</span></span><br><span class="line"></span><br><span class="line">c = <span class="string">'I Love LsgoGroup'</span></span><br><span class="line">print(len(c)) <span class="comment"># 16</span></span><br></pre></td></tr></table></figure></li><li><p><code>max(sub)</code> 返回序列或者参数集合中的最大值</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">print(max(<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>)) <span class="comment"># 5</span></span><br><span class="line">print(max([<span class="number">-8</span>, <span class="number">99</span>, <span class="number">3</span>, <span class="number">7</span>, <span class="number">83</span>])) <span class="comment"># 99</span></span><br><span class="line">print(max(<span class="string">'IloveLsgoGroup'</span>)) <span class="comment"># v</span></span><br></pre></td></tr></table></figure></li><li><p><code>min(sub)</code> 返回序列或参数集合中的最小值</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">print(min(<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>)) <span class="comment"># 1</span></span><br><span class="line">print(min([<span class="number">-8</span>, <span class="number">99</span>, <span class="number">3</span>, <span class="number">7</span>, <span class="number">83</span>])) <span class="comment"># -8</span></span><br><span class="line">print(min(<span class="string">'IloveLsgoGroup'</span>)) <span class="comment"># G</span></span><br></pre></td></tr></table></figure></li><li><p><code>sum(iterable[, start=0])</code> 返回序列 <code>iterable</code> 与可选参数 <code>start</code> 的总和。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span 
class="line">print(sum([<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>])) <span class="comment"># 25</span></span><br><span class="line">print(sum([<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>], <span class="number">10</span>)) <span class="comment"># 35</span></span><br><span class="line">print(sum((<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>))) <span class="comment"># 25</span></span><br><span class="line">print(sum((<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>, <span class="number">9</span>), <span class="number">20</span>)) <span class="comment"># 45</span></span><br></pre></td></tr></table></figure></li><li><p><code>sorted(iterable, key=None, reverse=False)</code> 对所有可迭代的对象进行排序操作。</p><ul><li><code>iterable</code> – 可迭代对象。</li><li><code>key</code> – 主要是用来进行比较的元素,只有一个参数,具体的函数的参数就是取自于可迭代对象中,指定可迭代对象中的一个元素来进行排序。</li><li><code>reverse</code> – 排序规则,<code>reverse = True</code> 降序 , <code>reverse = False</code> 升序(默认)。</li><li>返回重新排序的列表。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">x = [<span class="number">-8</span>, <span class="number">99</span>, <span class="number">3</span>, <span class="number">7</span>, <span class="number">83</span>]</span><br><span class="line">print(sorted(x)) <span class="comment"># [-8, 3, 7, 83, 99]</span></span><br><span class="line">print(sorted(x, reverse=<span class="literal">True</span>)) <span class="comment"># [99, 83, 7, 3, -8]</span></span><br><span class="line"></span><br><span class="line">t = ({<span class="string">"age"</span>: <span class="number">20</span>, <span class="string">"name"</span>: <span class="string">"a"</span>}, {<span class="string">"age"</span>: <span class="number">25</span>, <span class="string">"name"</span>: <span class="string">"b"</span>}, {<span class="string">"age"</span>: <span class="number">10</span>, <span class="string">"name"</span>: <span class="string">"c"</span>})</span><br><span class="line">x = sorted(t, key=<span class="keyword">lambda</span> a: a[<span class="string">"age"</span>])</span><br><span class="line">print(x)</span><br><span class="line"><span class="comment"># [{'age': 10, 'name': 'c'}, {'age': 20, 'name': 'a'}, {'age': 25, 'name': 'b'}]</span></span><br></pre></td></tr></table></figure></li></ul></li><li><p><code>reversed(seq)</code> 函数返回一个反转的迭代器。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span 
class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line">s = <span class="string">'lsgogroup'</span></span><br><span class="line">x = reversed(s)</span><br><span class="line">print(type(x)) <span class="comment"># <class 'reversed'></span></span><br><span class="line">print(x) <span class="comment"># <reversed object at 0x000002507E8EC2C8></span></span><br><span class="line">print(list(x))</span><br><span class="line"><span class="comment"># ['p', 'u', 'o', 'r', 'g', 'o', 'g', 's', 'l']</span></span><br><span class="line"></span><br><span class="line">t = (<span class="string">'l'</span>, <span class="string">'s'</span>, <span class="string">'g'</span>, <span class="string">'o'</span>, <span class="string">'g'</span>, <span class="string">'r'</span>, <span class="string">'o'</span>, <span class="string">'u'</span>, <span class="string">'p'</span>)</span><br><span class="line">print(list(reversed(t)))</span><br><span class="line"><span class="comment"># ['p', 'u', 'o', 'r', 'g', 'o', 'g', 's', 'l']</span></span><br><span class="line"></span><br><span class="line">r = range(<span class="number">5</span>, <span class="number">9</span>)</span><br><span class="line">print(list(reversed(r)))</span><br><span class="line"><span class="comment"># [8, 7, 6, 5]</span></span><br><span class="line"></span><br><span class="line">x = [<span class="number">-8</span>, <span class="number">99</span>, <span class="number">3</span>, <span class="number">7</span>, <span class="number">83</span>]</span><br><span class="line">print(list(reversed(x)))</span><br><span class="line"><span class="comment"># [83, 7, 3, 99, -8]</span></span><br></pre></td></tr></table></figure></li><li><p><code>enumerate(sequence, [start=0])</code> 用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 <code>for</code> 循环当中。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">seasons = [<span class="string">'Spring'</span>, <span class="string">'Summer'</span>, <span class="string">'Fall'</span>, <span class="string">'Winter'</span>]</span><br><span class="line">a = list(enumerate(seasons))</span><br><span class="line">print(a) </span><br><span class="line"><span class="comment"># [(0, 'Spring'), (1, 'Summer'), (2, 'Fall'), (3, 'Winter')]</span></span><br><span class="line"></span><br><span class="line">b = list(enumerate(seasons, <span class="number">1</span>))</span><br><span class="line">print(b) </span><br><span class="line"><span class="comment"># [(1, 'Spring'), (2, 'Summer'), (3, 'Fall'), (4, 'Winter')]</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> i, element <span class="keyword">in</span> a:</span><br><span class="line"> print(<span class="string">'{0},{1}'</span>.format(i, element))</span><br><span class="line"><span class="comment"># 0,Spring</span></span><br><span class="line"><span class="comment"># 1,Summer</span></span><br><span class="line"><span 
class="comment"># 2,Fall</span></span><br><span class="line"><span class="comment"># 3,Winter</span></span><br></pre></td></tr></table></figure></li></ul>]]></content>
<tags>
<tag> python </tag>
</tags>
</entry>
<entry>
<title>[经验总结]取整与取余</title>
<link href="2020/12/19/%E7%BB%8F%E9%AA%8C%E6%80%BB%E7%BB%93-%E5%8F%96%E6%95%B4%E4%B8%8E%E5%8F%96%E4%BD%99/"/>
<url>2020/12/19/%E7%BB%8F%E9%AA%8C%E6%80%BB%E7%BB%93-%E5%8F%96%E6%95%B4%E4%B8%8E%E5%8F%96%E4%BD%99/</url>
<content type="html"><![CDATA[<p>最近开始复习、深挖 Python 基础知识,有机会深入探索一些以前没有想过的事情。</p><p>我们知道,Python 内置函数 <code>int</code> 和 <code>round</code> 可以把一个浮点数取整,比如</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>int(<span class="number">1.1</span>)</span><br><span class="line"><span class="number">1</span></span><br><span class="line"><span class="meta">>>> </span>round(<span class="number">1.1</span>)</span><br><span class="line"><span class="number">1</span></span><br></pre></td></tr></table></figure><p>它们是如何工作的呢?</p><a id="more"></a><h1 id="int-函数"><a href="#int-函数" class="headerlink" title="int 函数"></a><code>int</code> 函数</h1><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">int(x, base=<span class="number">10</span>)</span><br></pre></td></tr></table></figure><p>本文中我们不关注一个数的进制,统一按十进制处理。我们来看一个例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(int(<span class="number">1.1</span>)) <span class="comment"># 1</span></span><br><span class="line">print(int(<span class="number">1.9</span>)) <span class="comment"># 1</span></span><br><span class="line">print(int(<span class="number">-1.1</span>)) <span class="comment"># -1</span></span><br><span class="line">print(int(<span class="number">-1.9</span>)) <span class="comment"># -1</span></span><br></pre></td></tr></table></figure><p>看起来 <code>int</code> 函数就是简单地截取小数点前面的数值。</p><h1 id="round-函数"><a href="#round-函数" class="headerlink" title="round 函数"></a><code>round</code> 函数</h1><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">round(x, ndigits=<span class="literal">None</span>)</span><br></pre></td></tr></table></figure><p>这里 <code>round</code> 的处理近似于四舍五入,我们先看几个没有争议的例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(round(<span class="number">1.1</span>)) <span class="comment"># 1</span></span><br><span class="line">print(round(<span class="number">1.9</span>)) <span class="comment"># 2</span></span><br><span class="line">print(round(<span class="number">-1.1</span>)) <span class="comment"># -1</span></span><br><span class="line">print(round(<span class="number">-1.9</span>)) <span class="comment"># -2</span></span><br></pre></td></tr></table></figure><p>现在有意思的时候到了,<code>.5</code> 应该如何进位呢?</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(round(<span class="number">1.5</span>)) <span class="comment"># 2</span></span><br><span class="line">print(round(<span class="number">2.5</span>)) <span class="comment"># 2</span></span><br><span class="line">print(round(<span class="number">-1.5</span>)) <span class="comment"># -2</span></span><br><span 
class="line">print(round(<span class="number">-2.5</span>)) <span class="comment"># -2</span></span><br></pre></td></tr></table></figure><p>根据<a href="https://docs.python.org/3/library/functions.html#round" target="_blank" rel="noopener" title="round 函数">官方文档</a>,rounding 会选择被 2 整除的数,所以 1.5 和 2.5 的 rounding 都是 2。当 <code>ndigits</code> 取非 0 值时,这个原则仍然适用。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">print(round(<span class="number">1.55</span>, <span class="number">1</span>)) <span class="comment"># 1.6</span></span><br><span class="line">print(round(<span class="number">1.45</span>, <span class="number">1</span>)) <span class="comment"># 1.4</span></span><br></pre></td></tr></table></figure><p>上述例子中的 <code>ndigits</code> 在为 0 时为 <code>None</code>,取 0 会有什么变化吗?</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">round(<span class="number">1.1</span>, ndigits=<span class="literal">None</span>) <span class="comment"># 1</span></span><br><span class="line">round(<span class="number">1.1</span>, <span class="number">0</span>) <span class="comment"># 1.0</span></span><br></pre></td></tr></table></figure><p>是有变化的!如果 <code>ndigits=None</code> 或者干脆省略,返回一个整数;如果 <code>ndigits=0</code>,返回一个带有一位小数点的整数。无独有偶,<code>numpy</code>,<code>TensorFlow</code> 和 <code>PyTorch</code> 也有 <code>numpy.round</code>,<code>tensorflow.math.round</code> 和 <code>torch.round</code> 与 Python 原生 <code>round</code> 函数对应,它们与原生函数有什么区别呢?</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">print(np.round(<span class="number">1.5</span>)) <span class="comment"># 2.0</span></span><br><span class="line">print(np.round(<span class="number">2.5</span>)) <span class="comment"># 2.0</span></span><br><span class="line"></span><br><span class="line">print(tf.round(tf.Variable(<span class="number">1.5</span>)).numpy()) <span class="comment"># 2.0</span></span><br><span class="line">print(tf.round(tf.Variable(<span class="number">1.5</span>)).numpy()) <span class="comment"># 2.0</span></span><br><span class="line"></span><br><span class="line">print(torch.round(torch.tensor(<span class="number">1.5</span>)).item()) <span class="comment"># 2.0</span></span><br><span class="line">print(torch.round(torch.tensor(<span class="number">2.5</span>)).item()) <span class="comment"># 2.0</span></span><br></pre></td></tr></table></figure><p>可以看到,numpy,TensorFlow 和 PyTorch 的 <code>round</code> 函数的工作原理与 Python 原生函数相同。顺便提一句,TensorFlow 和 PyTorch 的 <code>round</code> 函数只能取整,返回一个带有一位小数点的整数;numpy 的 <code>round</code> 函数与 Python 原生函数相同,但是变量名不是 <code>nsdigits</code> 而是 <code>decimals</code>。</p><h1 id="Python-里的相除取整(-)与相除取余(-)"><a href="#Python-里的相除取整(-)与相除取余(-)" class="headerlink" title="Python 里的相除取整(//)与相除取余(%)"></a>Python 里的相除取整(<code>//</code>)与相除取余(<code>%</code>)</h1><p>所谓的相除取整和相除取余很好理解,就是一个除法如果不能整除的话就分别取整除部分和余数部分:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span 
class="line">print(<span class="number">123</span> // <span class="number">10</span>) <span class="comment"># 12</span></span><br><span class="line">print(<span class="number">123</span> % <span class="number">10</span>) <span class="comment"># 3</span></span><br></pre></td></tr></table></figure><p>本来很简单的一件小事遇到负数就有意思了:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="number">-123</span> // <span class="number">10</span>) <span class="comment"># -13</span></span><br><span class="line">print(<span class="number">123</span> // <span class="number">-10</span>) <span class="comment"># -13</span></span><br><span class="line">print(<span class="number">-123</span> // <span class="number">-10</span>) <span class="comment"># 12</span></span><br><span class="line">print(<span class="number">-123</span> % <span class="number">10</span>) <span class="comment"># 7</span></span><br><span class="line">print(<span class="number">123</span> % <span class="number">-10</span>) <span class="comment"># -7</span></span><br><span class="line">print(<span class="number">-123</span> % <span class="number">-10</span>) <span class="comment"># -3</span></span><br></pre></td></tr></table></figure><p>这是怎么回事呢?<a href="https://blog.csdn.net/sun___M/article/details/83142126" target="_blank" rel="noopener" title="相除取整与相除取余"></a></p><h2 id="相除取整"><a href="#相除取整" class="headerlink" title="相除取整"></a>相除取整</h2><p>我们把几个除法的结果比较一下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="number">123</span> / <span class="number">10</span>) <span class="comment"># 12.3</span></span><br><span class="line">print(<span class="number">123</span> // <span class="number">10</span>) <span class="comment"># 12</span></span><br><span class="line">print(int(<span class="number">123</span> / <span class="number">10</span>)) <span class="comment"># 12</span></span><br><span class="line">print(<span class="number">-123</span> / <span class="number">10</span>) <span class="comment"># -12.3</span></span><br><span class="line">print(<span class="number">-123</span> // <span class="number">10</span>) <span class="comment"># -13</span></span><br><span class="line">print(int(<span class="number">-123</span> / <span class="number">10</span>)) <span class="comment"># -12</span></span><br><span class="line">print(<span class="number">-123</span> / <span class="number">-10</span>) <span class="comment"># 12.3</span></span><br><span class="line">print(<span class="number">-123</span> // <span class="number">-10</span>) <span class="comment"># 12</span></span><br><span class="line">print(int(<span class="number">-123</span> / <span class="number">-10</span>)) <span class="comment"># 12</span></span><br></pre></td></tr></table></figure><p>在没有看函数源代码的情况下,我们可以大概说,<code>//</code> 操作为向下取整。如果想要一个负结果的向上取整的结果,可以使用 <code>int</code> 配合普通除法。负数除以一个负数与正数除以一个正数的结果相同。</p><h2 id="相除取余"><a href="#相除取余" class="headerlink" 
title="相除取余"></a>相除取余</h2><p>还是比较几个除法取余的结果:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="number">123</span> % <span class="number">10</span>) <span class="comment"># 3</span></span><br><span class="line">print(<span class="number">123</span> % <span class="number">-10</span>) <span class="comment"># -7</span></span><br><span class="line">print(<span class="number">-123</span> % <span class="number">10</span>) <span class="comment"># 7</span></span><br><span class="line">print(<span class="number">-123</span> % <span class="number">-10</span>) <span class="comment"># -3</span></span><br></pre></td></tr></table></figure><p>其实在 Python 中,取余的计算公式与别的语言并没有什么区别:<br>$$r = a - n * [a // n]$$<br>其中 <code>r</code> 是余数,<code>a</code> 是被除数,<code>n</code> 是除数。不过在 <code>a // n</code> 这一步,当 <code>a</code> 是负数的时候,上面提到会向下取整,所以有:<br>$$-123 % 10 = -123 - 10 * (-123 // 10) = -123 - 10 * (-13) = 7$$<br>其余的两个相除取余也可以按照此法推演出来。</p>]]></content>
<tags>
<tag> 经验总结 </tag>
</tags>
</entry>
<entry>
<title>[DSU&阿里云天池] Python训练营 Task 1</title>
<link href="2020/12/19/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task-1/"/>
<url>2020/12/19/DSU-%E9%98%BF%E9%87%8C%E4%BA%91%E5%A4%A9%E6%B1%A0-Python%E8%AE%AD%E7%BB%83%E8%90%A5-Task-1/</url>
<content type="html"><![CDATA[<p>[TOC]<br>今天参加了由阿里云天池开展的 Python 训练营,借这个机会回顾一下 Python 的基础知识,并深入一些以前没有注意到的点。本文为 task 1。</p><h1 id="变量、运算符与数据类型"><a href="#变量、运算符与数据类型" class="headerlink" title="变量、运算符与数据类型"></a>变量、运算符与数据类型</h1><h2 id="注释"><a href="#注释" class="headerlink" title="注释"></a>注释</h2><h3 id="单行注释"><a href="#单行注释" class="headerlink" title="单行注释"></a>单行注释</h3><p><code>#</code> 表示其后面的整行内容为注释,后面的所有文字会被忽略。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 这是一个注释</span></span><br><span class="line">print(<span class="string">'Hello Word!'</span>) <span class="comment"># '#'也可以放在一条代码的后面</span></span><br><span class="line">Hello Word!</span><br></pre></td></tr></table></figure><a id="more"></a><h3 id="多行注释"><a href="#多行注释" class="headerlink" title="多行注释"></a>多行注释</h3><p>使用 <code>''' ''''</code> 或 <code>""" """</code> 表示多行注释,在三个引号内的内容为注释。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="string">'''</span></span><br><span class="line"><span class="string">这是多行注释,用三个单引号</span></span><br><span class="line"><span class="string">这是多行注释,用三个单引号</span></span><br><span class="line"><span class="string">这是多行注释,用三个单引号</span></span><br><span class="line"><span class="string">'''</span></span><br><span class="line">print(<span class="string">'Hello China'</span>)</span><br><span class="line"><span class="comment"># Hello China</span></span><br><span class="line"></span><br><span class="line"><span class="string">"""</span></span><br><span class="line"><span class="string">这是多行注释,用三个双引号</span></span><br><span class="line"><span class="string">这是多行注释,用三个双引号</span></span><br><span class="line"><span class="string">这是多行注释,用三个双引号</span></span><br><span class="line"><span class="string">"""</span></span><br><span class="line">print(<span class="string">'Hello China'</span>)</span><br><span class="line"><span class="comment"># Hello China</span></span><br></pre></td></tr></table></figure><h2 id="运算符"><a href="#运算符" class="headerlink" title="运算符"></a>运算符</h2><h3 id="算术运算符"><a href="#算术运算符" class="headerlink" title="算术运算符"></a>算术运算符</h3><p>算术运算符有 <code>+</code>(加),<code>-</code>(减),<code>*</code>(乘),<code>/</code>(除),<code>//</code>(除取整),<code>%</code>(除取余)和 <code>**</code>(幂次)。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="number">1</span> + <span class="number">2</span>) <span class="comment"># 3</span></span><br><span class="line">print(<span class="number">4</span> - <span class="number">3</span>) <span class="comment"># 1</span></span><br><span class="line">print(<span 
class="number">5</span> * <span class="number">6</span>) <span class="comment"># 30</span></span><br><span class="line">print(<span class="number">7</span> / <span class="number">8</span>) <span class="comment"># 0.875</span></span><br><span class="line">print(<span class="number">9</span> // <span class="number">10</span>) <span class="comment"># 0,因为 9 不够 1 个 10</span></span><br><span class="line">print(<span class="number">1</span> % <span class="number">2</span>) <span class="comment"># 1,因为 1 除以 2 余 1</span></span><br><span class="line">print(<span class="number">3</span> ** <span class="number">4</span>) <span class="comment"># 81,因为 3 的 4 次幂是 81</span></span><br></pre></td></tr></table></figure><p>相除取整和相除取余在涉及到负数的时候变得复杂,详情请见<a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485633&idx=1&sn=9d29d82c929feabfed60c45898381a3a&chksm=cf06b9fdf87130eb49e334c806024bff962afafc12f33bfbf71aa5cb486197961df9be88116a&token=457736762&lang=zh_CN#rd" target="_blank" rel="noopener">这里</a>。</p><h3 id="赋值运算符"><a href="#赋值运算符" class="headerlink" title="赋值运算符"></a>赋值运算符</h3><p>赋值运算符为算术运算符后面加个 <code>=</code>,在进行运算以后将新值赋给原变量,即为原地操作(in-place)。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line">a = <span class="number">1</span></span><br><span class="line">b = <span class="number">2</span></span><br><span class="line"></span><br><span class="line">a = a + b <span class="comment"># 现在 a 的值为 3</span></span><br><span class="line">a += b <span class="comment"># 新的值被原地赋给了 a</span></span><br><span class="line"></span><br><span class="line">c = <span class="number">3</span></span><br><span class="line">c += <span class="number">1</span> <span class="comment"># 4</span></span><br><span class="line">c -= <span class="number">2</span> <span class="comment"># 2</span></span><br><span class="line">c *= <span class="number">3</span> <span class="comment"># 6</span></span><br><span class="line">c /= <span class="number">4</span> <span class="comment"># 1.5</span></span><br><span class="line"></span><br><span class="line">d = <span class="number">4</span></span><br><span class="line">d **= <span class="number">2</span> <span class="comment"># 16</span></span><br><span class="line">d //= <span class="number">3</span> <span class="comment"># 5</span></span><br><span class="line">d %= <span class="number">4</span> = <span class="number">1</span></span><br></pre></td></tr></table></figure><h3 id="比较运算符"><a href="#比较运算符" class="headerlink" title="比较运算符"></a>比较运算符</h3><p>比较运算符有 <code>></code>(大于),<code><</code>(小于),<code>>=</code>(大于等于),<code><=</code>(小于等于),<code>==</code>(等于)和 <code>!=</code>(不等于),返回 <code>True</code> 或 <code>False</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="number">1</span> > <span class="number">2</span>) <span 
class="comment"># False</span></span><br><span class="line">print(<span class="number">2</span> < <span class="number">4</span>) <span class="comment"># True</span></span><br><span class="line">print(<span class="number">5</span> == <span class="number">5</span>) <span class="comment"># True</span></span><br><span class="line">print(<span class="number">6</span> != <span class="number">6</span>) <span class="comment"># False</span></span><br></pre></td></tr></table></figure><h3 id="逻辑运算符"><a href="#逻辑运算符" class="headerlink" title="逻辑运算符"></a>逻辑运算符</h3><p>逻辑运算符有 <code>and</code>(和、与),<code>or</code>(或)和 <code>not</code>(非)。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(<span class="literal">True</span>) <span class="comment"># True</span></span><br><span class="line">print(<span class="literal">True</span> <span class="keyword">and</span> <span class="literal">False</span>) <span class="comment"># False</span></span><br><span class="line">print(<span class="literal">True</span> <span class="keyword">or</span> <span class="literal">False</span>) <span class="comment"># True</span></span><br><span class="line">print(<span class="keyword">not</span> <span class="literal">True</span>) <span class="comment"># False</span></span><br></pre></td></tr></table></figure><p>只有 <code>and</code> 两边同时为真才返回 <code>True</code>, 取余情况返回 <code>False</code>;只有 <code>or</code> 两边同时为假才返回 <code>False</code>,取余情况返回 <code>True</code>;<code>not</code> 返回与原判断结果相反的结果。</p><h3 id="其它操作符"><a href="#其它操作符" class="headerlink" title="其它操作符"></a>其它操作符</h3><p><code>is</code>(是),<code>in</code>(存在)</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="number">1</span> <span class="keyword">in</span> [<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>] <span class="comment"># true</span></span><br><span class="line"><span class="string">'d'</span> <span class="keyword">in</span> <span class="string">'abc'</span> <span class="comment"># False</span></span><br></pre></td></tr></table></figure><h3 id="is-与-的区别"><a href="#is-与-的区别" class="headerlink" title="is 与 == 的区别"></a><code>is</code> 与 <code>==</code> 的区别</h3><ul><li>is, is not 对比的是两个变量的内存地址</li><li>==, != 对比的是两个变量的值</li><li>比较的两个变量,指向的都是地址不可变的类型(str等),那么is,is not 和 ==,!= 是完全等价的。</li><li>对比的两个变量,指向的是地址可变的类型(list,dict,tuple等),则两者是有区别的。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">a, b = <span class="string">'abc'</span>, <span class="string">'abc'</span></span><br><span class="line">print(a <span class="keyword">is</span> b, a == b) <span class="comment"># True True</span></span><br><span class="line">print(a <span class="keyword">is</span> <span class="keyword">not</span> b, a != b) <span class="comment"># False False</span></span><br><span class="line"></span><br><span class="line">a, b = [<span class="string">'abc'</span>], [<span class="string">'abc'</span>]</span><br><span class="line">print(a <span 
class="keyword">is</span> b, a == b) <span class="comment"># False True</span></span><br><span class="line">print(a <span class="keyword">is</span> <span class="keyword">not</span> b, a != b) <span class="comment"># True False</span></span><br></pre></td></tr></table></figure><h3 id="运算符的优先级"><a href="#运算符的优先级" class="headerlink" title="运算符的优先级"></a>运算符的优先级</h3>不同运算符的优先度不同,优先级从高到低为:<br><code>**</code> > <code>*, /, %, //</code> > <code>+-</code> > <code>&</code> > <code>>, >=, <, <=</code> > <code>==, !=</code> > <code>=, +=, -=, *=, /=, %=, //=</code> > <code>is, not is</code> > <code>in, not in</code> > <code>and, or, not</code><h2 id="变量与赋值"><a href="#变量与赋值" class="headerlink" title="变量与赋值"></a>变量与赋值</h2></li></ul><ol><li>在使用变量之前,需要对其先赋值。</li><li>变量名可以包括字母、数字、下划线、但变量名不能以数字开头。 </li><li>Python 变量名是大小写敏感的,foo != Foo。</li><li>同一行可以赋值多个变量,如 <code>a, b = 1, 2</code>,<code>a</code> 和 <code>b</code> 分别赋予了 1 和 2。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">a, b = <span class="number">1</span>, <span class="number">2</span></span><br><span class="line">c = a + b</span><br><span class="line">print(c) <span class="comment"># 3</span></span><br></pre></td></tr></table></figure><h2 id="数据类型与转换"><a href="#数据类型与转换" class="headerlink" title="数据类型与转换"></a>数据类型与转换</h2>常见的数据类型有:<code>str</code>(字符),<code>int</code>(整型),<code>float</code>(浮点型)和 <code>bool</code>(布尔型)。布尔变量只能是 <code>True</code> 和 <code>False</code>。除了直接给变量赋予 <code>True</code> 和 <code>False</code> 也可以用 <code>bool(X)</code> 让 Python 自行判断,<code>X</code> 可以是一个值(整型,浮点型或布尔型)也可以是一个容器(字符串,列表,元组,集合或字典),判断依据是:</li><li>对于数值,0 为 <code>False</code>,非 0 为 <code>True</code>;</li><li>对于容器,空容器为 <code>False</code>,非空容器为 <code>True</code>。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">print(bool(<span class="literal">False</span>), bool(<span class="literal">True</span>)) <span class="comment"># False, True</span></span><br><span class="line">print(bool(<span class="number">0</span>), bool(<span class="number">1</span>)) <span class="comment"># False, True</span></span><br><span class="line">print(bool(<span class="number">0.0</span>), bool(<span class="number">1.5</span>)) <span class="comment"># False, True</span></span><br><span class="line"></span><br><span class="line">print(bool(<span class="string">''</span>), bool(<span class="string">'abc'</span>)) <span class="comment"># False, True</span></span><br><span class="line">print(bool([]), bool([<span class="number">1</span>, <span class="number">2</span>])) <span class="comment"># False, True</span></span><br><span class="line">print(bool(()), bool((<span class="number">1</span>, <span class="number">2</span>))) <span class="comment"># False, True</span></span><br><span class="line">print(bool({}), bool({<span class="number">1</span>, <span class="number">2</span>})) <span class="comment"># False, True</span></span><br><span class="line">print(bool({}), bool({<span class="string">'a'</span>:<span class="number">1</span>})) <span class="comment"># False, 
<h2 id="变量与赋值"><a href="#变量与赋值" class="headerlink" title="变量与赋值"></a>变量与赋值</h2></li></ul><ol><li>在使用变量之前,需要对其先赋值。</li><li>变量名可以包括字母、数字、下划线,但变量名不能以数字开头。</li><li>Python 变量名是大小写敏感的,foo != Foo。</li><li>同一行可以赋值多个变量,如 <code>a, b = 1, 2</code>,<code>a</code> 和 <code>b</code> 分别赋予了 1 和 2。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">a, b = <span class="number">1</span>, <span class="number">2</span></span><br><span class="line">c = a + b</span><br><span class="line">print(c) <span class="comment"># 3</span></span><br></pre></td></tr></table></figure><h2 id="数据类型与转换"><a href="#数据类型与转换" class="headerlink" title="数据类型与转换"></a>数据类型与转换</h2>常见的数据类型有:<code>str</code>(字符串),<code>int</code>(整型),<code>float</code>(浮点型)和 <code>bool</code>(布尔型)。布尔变量只能是 <code>True</code> 和 <code>False</code>。除了直接给变量赋予 <code>True</code> 和 <code>False</code>,也可以用 <code>bool(X)</code> 让 Python 自行判断,<code>X</code> 可以是一个值(整型,浮点型或布尔型)也可以是一个容器(字符串,列表,元组,集合或字典),判断依据是:</li><li>对于数值,0 为 <code>False</code>,非 0 为 <code>True</code>;</li><li>对于容器,空容器为 <code>False</code>,非空容器为 <code>True</code>。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">print(bool(<span class="literal">False</span>), bool(<span class="literal">True</span>)) <span class="comment"># False, True</span></span><br><span class="line">print(bool(<span class="number">0</span>), bool(<span class="number">1</span>)) <span class="comment"># False, True</span></span><br><span class="line">print(bool(<span class="number">0.0</span>), bool(<span class="number">1.5</span>)) <span class="comment"># False, True</span></span><br><span class="line"></span><br><span class="line">print(bool(<span class="string">''</span>), bool(<span class="string">'abc'</span>)) <span class="comment"># False, True</span></span><br><span class="line">print(bool([]), bool([<span class="number">1</span>, <span class="number">2</span>])) <span class="comment"># False, True</span></span><br><span class="line">print(bool(()), bool((<span class="number">1</span>, <span class="number">2</span>))) <span class="comment"># False, True</span></span><br><span class="line">print(bool({}), bool({<span class="number">1</span>, <span class="number">2</span>})) <span class="comment"># False, True</span></span><br><span class="line">print(bool({}), bool({<span class="string">'a'</span>:<span class="number">1</span>})) <span class="comment"># False, True</span></span><br></pre></td></tr></table></figure></li></ol><p><strong>获取类型信息</strong>:</p><ul><li><code>type(X)</code>:返回 X 的类型信息</li><li><code>isinstance(var, type)</code>:判断 <code>var</code> 是不是 <code>type</code> 类型,返回布尔值<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">print(type(<span class="number">1</span>)) <span class="comment"># <class 'int'></span></span><br><span class="line">print(type([<span class="number">1</span>, <span class="number">2</span>])) <span class="comment"># <class 'list'></span></span><br><span class="line"></span><br><span class="line">print(isinstance(<span class="number">1</span>, int)) <span class="comment"># True</span></span><br><span class="line">print(isinstance({<span class="number">1</span>, <span class="number">2</span>}, set)) <span class="comment"># True</span></span><br></pre></td></tr></table></figure><h2 id="print-函数"><a href="#print-函数" class="headerlink" title="print() 函数"></a><code>print()</code> 函数</h2>函数的语法为<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">print(*objects, sep=<span class="string">' '</span>, end=<span class="string">'\n'</span>, file=sys.stdout, flush=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure></li></ul><ol><li>将对象以字符串表示的方式格式化输出到流文件对象 file 里。其中所有非关键字参数都按 str() 方式转换为字符串输出;</li><li>关键字参数 sep 是多个输出对象之间的分隔符;</li><li>关键字参数 end 是输出结束时的字符,默认是换行符 \n ;</li><li>关键字参数 file 是定义流输出的文件,可以是标准的系统输出 sys.stdout ,也可以重定义为别的文件;</li><li>关键字参数 flush 是立即把内容输出到流文件,不作缓存。</li></ol><p>常用的参数为 <code>sep</code> 和 <code>end</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 修改 sep 参数</span></span><br><span class="line">print(<span class="string">'apple'</span>, <span class="string">'banana'</span>) <span class="comment"># sep = ' '</span></span><br><span class="line"><span class="comment"># apple banana</span></span><br><span class="line">print(<span class="string">'apple'</span>, <span class="string">'banana'</span>, sep=<span class="string">'&'</span>) <span class="comment"># apple 和 banana 之间用 & 分隔</span></span><br><span class="line"><span class="comment"># apple&banana</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 修改 end 参数</span></span><br><span class="line">fruits = [<span class="string">'apple'</span>, <span class="string">'banana'</span>]</span><br><span class="line">print(<span class="string">"This is printed with default end"</span>)</span><br><span
class="keyword">in</span> fruits:</span><br><span class="line"> print(item)</span><br><span class="line"><span class="comment"># This is printed with default end</span></span><br><span class="line"><span class="comment"># apple</span></span><br><span class="line"><span class="comment"># fruit</span></span><br><span class="line"></span><br><span class="line">print(<span class="string">"This is printed with 'end='&''"</span>)</span><br><span class="line"><span class="keyword">for</span> item <span class="keyword">in</span> fruits:</span><br><span class="line"> print(item, end=<span class="string">'&'</span>)</span><br><span class="line"><span class="comment"># This is printed with 'end='&''</span></span><br><span class="line">apple&banana&</span><br></pre></td></tr></table></figure><h1 id="条件语句"><a href="#条件语句" class="headerlink" title="条件语句"></a>条件语句</h1><h2 id="if-语句"><a href="#if-语句" class="headerlink" title="if 语句"></a><code>if</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> expression:</span><br><span class="line"> exp_true_action</span><br></pre></td></tr></table></figure><p>如果 <code>expression</code> 为真,则执行 <code>exp_true_action</code>;否则不会执行。<code>expression</code> 可为多重条件判断。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> <span class="number">2</span> > <span class="number">1</span> <span class="keyword">and</span> <span class="keyword">not</span> <span class="number">2</span> > <span class="number">3</span>:</span><br><span class="line"> print(<span class="string">'Correct!'</span>)</span><br><span class="line"><span class="comment"># Correct!</span></span><br></pre></td></tr></table></figure><h2 id="if-else-语句"><a href="#if-else-语句" class="headerlink" title="if-else 语句"></a><code>if-else</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> expression:</span><br><span class="line"> exp_true_action</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> exp_false_action</span><br></pre></td></tr></table></figure><p>如果 <code>expression</code> 为真,则执行 <code>exp_true_action</code>;否则执行 <code>exp_false_action</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">color = <span class="string">'red'</span></span><br><span class="line"><span class="keyword">if</span> color == <span class="string">'blue'</span>:</span><br><span class="line"> print(<span class="string">'Color is blue.'</span>)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">'Color is not blue.'</span>)</span><br><span class="line"><span class="comment"># Color is not blue.</span></span><br></pre></td></tr></table></figure><h2 id="if-elif-else-语句"><a 
href="#if-elif-else-语句" class="headerlink" title="if-elif-else 语句"></a><code>if-elif-else</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">if</span> expression1:</span><br><span class="line"> exp1_true_action</span><br><span class="line"><span class="keyword">elif</span> expression2:</span><br><span class="line"> exp2_true_action</span><br><span class="line"> .</span><br><span class="line"> .</span><br><span class="line"><span class="keyword">elif</span> expressionN:</span><br><span class="line"> expN_true_action</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> exp_false_action</span><br></pre></td></tr></table></figure><p>分别判断哪一个 <code>expression</code> 为真,哪个为真就执行哪个动作;全为假则执行 <code>exp_false_action</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">number = <span class="number">5</span></span><br><span class="line"><span class="keyword">if</span> number < <span class="number">3</span>:</span><br><span class="line"> print(<span class="string">'Number is smaller than 3.'</span>)</span><br><span class="line"><span class="keyword">elif</span> number < <span class="number">7</span>:</span><br><span class="line"> print(<span class="string">'Number is between 3 and 7.'</span>)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">'Number is greater than 7.'</span>)</span><br><span class="line"><span class="comment"># Number is between 3 and 7.</span></span><br></pre></td></tr></table></figure><h2 id="assert-语句"><a href="#assert-语句" class="headerlink" title="assert 语句"></a><code>assert</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">assert</span> expression, text</span><br></pre></td></tr></table></figure><p>如果 <code>expression</code> 为假,则中断程序运行,抛出 <code>AssertionError</code> 异常,异常信息为 <code>text</code>。</p><h1 id="循环语句"><a href="#循环语句" class="headerlink" title="循环语句"></a>循环语句</h1><h2 id="while-循环"><a href="#while-循环" class="headerlink" title="while 循环"></a><code>while</code> 循环</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">while</span> expression:</span><br><span class="line"> action</span><br></pre></td></tr></table></figure><p>如果 <code>expression</code> 为真,则执行 <code>action</code>,然后再判断 <code>expression</code> 是否为真,若还为真则再执行 <code>action</code>,再判断 <code>expression</code> 是否为真,…,直到 <code>expression</code> 为假,循环结束。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span 
class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">i = <span class="number">0</span></span><br><span class="line"><span class="keyword">while</span> i < <span class="number">2</span>:</span><br><span class="line"> print(i)</span><br><span class="line"> i += <span class="number">1</span></span><br><span class="line"><span class="comment"># 0</span></span><br><span class="line"><span class="comment"># 1</span></span><br></pre></td></tr></table></figure><h2 id="while-else-循环"><a href="#while-else-循环" class="headerlink" title="while-else 循环"></a><code>while-else</code> 循环</h2><p>当 <code>while</code> 循环正常执行完的情况下,执行 <code>else</code> 输出,如果 <code>while</code> 循环中执行了跳出循环的语句,比如 <code>break</code>,将不执行 <code>else</code> 代码块的内容。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line">count = <span class="number">0</span></span><br><span class="line"><span class="keyword">while</span> count <<span class="number">5</span>:</span><br><span class="line"> print(<span class="string">f'<span class="subst">{count}</span> is less than 5'</span>)</span><br><span class="line"> count += <span class="number">1</span></span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">f'<span class="subst">{count}</span> is not less than 5'</span>)</span><br><span class="line"><span class="comment"># count is less than 5</span></span><br><span class="line"><span class="comment"># count is less than 5</span></span><br><span class="line"><span class="comment"># count is less than 5</span></span><br><span class="line"><span class="comment"># count is less than 5</span></span><br><span class="line"><span class="comment"># count is less than 5</span></span><br><span class="line"><span class="comment"># count is not less than 5</span></span><br><span class="line">count = <span class="number">0</span></span><br><span class="line"><span class="keyword">while</span> count < <span class="number">5</span>:</span><br><span class="line"> print(<span class="string">f'<span class="subst">{count}</span> is less than 5'</span>)</span><br><span class="line"> count = <span class="number">6</span></span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">f'count is not less than 5'</span>)</span><br><span class="line"><span class="comment"># 0 is less than 5</span></span><br></pre></td></tr></table></figure><h2 id="for-循环"><a href="#for-循环" class="headerlink" title="for 循环"></a><code>for</code> 循环</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span 
class="keyword">for</span> 迭代变量 <span class="keyword">in</span> 迭代器:</span><br><span class="line"> action</span><br></pre></td></tr></table></figure><p><code>for</code> 循环是迭代循环,在 Python 中相当于一个通用的序列迭代器,可以遍历任何有序序列,如 <code>str、list、tuple</code> 等,也可以遍历任何可迭代对象,如 <code>dict</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> s <span class="keyword">in</span> <span class="string">'abc'</span>:</span><br><span class="line"> print(s)</span><br><span class="line"><span class="comment"># a</span></span><br><span class="line"><span class="comment"># b</span></span><br><span class="line"><span class="comment"># c</span></span><br></pre></td></tr></table></figure><h2 id="for-else-循环"><a href="#for-else-循环" class="headerlink" title="for-else 循环"></a><code>for-else</code> 循环</h2><p>当 <code>for</code> 循环正常执行完的情况下,执行 <code>else</code> 输出,如果 <code>for</code> 循环中执行了跳出循环的语句,比如 <code>break</code>,将不执行 <code>else</code> 代码块的内容,与 <code>while - else</code> 语句一样。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">2</span>):</span><br><span class="line"> print(i)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">'done'</span>)</span><br><span class="line"><span class="comment"># 0</span></span><br><span class="line"><span class="comment"># 1</span></span><br><span class="line"><span class="comment"># done</span></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">2</span>):</span><br><span class="line"> <span class="keyword">if</span> i % <span class="number">2</span> == <span class="number">1</span>:</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> print(i)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">'done'</span>)</span><br><span class="line"><span class="comment"># 0</span></span><br></pre></td></tr></table></figure><h2 id="range-函数"><a href="#range-函数" class="headerlink" title="range 函数"></a><code>range</code> 函数</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">range([start,] stop[, step=<span class="number">1</span>])</span><br></pre></td></tr></table></figure><p><code>range</code> 这个内置函数的作用是生成一个从 <code>start</code> 参数的值开始到 <code>stop</code>参数的值结束的数字序列,该序列包含 <code>start</code> 的值但不包含 <code>stop</code> 的值。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span 
class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">1</span>, <span class="number">5</span>, <span class="number">2</span>):</span><br><span class="line"> print(i)</span><br><span class="line"><span class="comment"># 1</span></span><br><span class="line"><span class="comment"># 3</span></span><br></pre></td></tr></table></figure><h2 id="enumerate-函数"><a href="#enumerate-函数" class="headerlink" title="enumerate 函数"></a><code>enumerate</code> 函数</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">enumerate(sequence, [start=<span class="number">0</span>])</span><br></pre></td></tr></table></figure><p>返回枚举对象,可与 <code>for</code> 循环连用。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line">letters = [<span class="string">'a'</span>, <span class="string">'b'</span>, <span class="string">'c'</span>, <span class="string">'d'</span>]</span><br><span class="line">lst = list(enumerate(letters))</span><br><span class="line">print(lst)</span><br><span class="line"><span class="comment"># [(0, 'a'), (1, 'b'), (2, 'c'), (3, 'd')]</span></span><br><span class="line"><span class="keyword">for</span> idx, letter <span class="keyword">in</span> enumerate(letters, <span class="number">1</span>):</span><br><span class="line"> print(idx, letter)</span><br><span class="line"><span class="comment"># 1 a</span></span><br><span class="line"><span class="comment"># 2 b</span></span><br><span class="line"><span class="comment"># 3 c</span></span><br><span class="line"><span class="comment"># 4 d</span></span><br></pre></td></tr></table></figure><h2 id="break-语句"><a href="#break-语句" class="headerlink" title="break 语句"></a><code>break</code> 语句</h2><p><code>break</code> 语句可以跳出当前所在层的循环,例子见上。</p><h2 id="continue-语句"><a href="#continue-语句" class="headerlink" title="continue 语句"></a><code>continue</code> 语句</h2><p><code>continue</code> 终止本轮循环并开始下一轮循环。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">3</span>):</span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> i % <span class="number">2</span>:</span><br><span class="line"> <span class="keyword">continue</span></span><br><span class="line"> print(i)</span><br><span class="line"><span class="comment"># 1</span></span><br></pre></td></tr></table></figure><h2 id="pass-语句"><a href="#pass-语句" class="headerlink" title="pass 语句"></a><code>pass</code> 语句</h2><p><code>pass</code> 语句的意思是“不做任何事”,即不做任何操作,只起到占位的作用,其作用是为了保持程序结构的完整性。尽管 <code>pass</code> 语句不做任何操作,但如果暂时不确定要在一个位置放上什么样的代码,可以先放置一个 <code>pass</code> 语句,让代码可以正常运行。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span 
class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">3</span>):</span><br><span class="line"> <span class="keyword">if</span> <span class="keyword">not</span> i % <span class="number">2</span>:</span><br><span class="line"> <span class="keyword">pass</span></span><br><span class="line"> print(i)</span><br><span class="line"><span class="comment"># 0</span></span><br><span class="line"><span class="comment"># 1</span></span><br><span class="line"><span class="comment"># 2</span></span><br></pre></td></tr></table></figure><h2 id="解析式"><a href="#解析式" class="headerlink" title="解析式"></a>解析式</h2><h3 id="列表解析式"><a href="#列表解析式" class="headerlink" title="列表解析式"></a>列表解析式</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">[expr <span class="keyword">for</span> val <span class="keyword">in</span> iterable [<span class="keyword">if</span> condition]]</span><br></pre></td></tr></table></figure><p>返回一个根据解析式条件创建的列表。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">[i <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">5</span>) <span class="keyword">if</span> i % <span class="number">2</span> == <span class="number">1</span>]</span><br><span class="line"><span class="comment"># [1 ,3]</span></span><br></pre></td></tr></table></figure><h3 id="字典解析式"><a href="#字典解析式" class="headerlink" title="字典解析式"></a>字典解析式</h3><p>类似列表解析式,只是变量为键值对。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">{i:j <span class="keyword">for</span> i, j <span class="keyword">in</span> enumerate(<span class="string">'abcde'</span>)}</span><br><span class="line"><span class="comment"># {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e'}</span></span><br></pre></td></tr></table></figure><h3 id="集合解析式"><a href="#集合解析式" class="headerlink" title="集合解析式"></a>集合解析式</h3><p>类似列表解析式。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">{i <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">5</span>) <span class="keyword">if</span> i % <span class="number">2</span> == <span class="number">1</span>}</span><br><span class="line"><span class="comment"># {1, 3}</span></span><br></pre></td></tr></table></figure><h3 id="生成式"><a href="#生成式" class="headerlink" title="生成式"></a>生成式</h3><p>类似列表解析式,只是返回一个迭代器对象。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">(i <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">3</span>) <span class="keyword">if</span> i % <span class="number">2</span> == <span class="number">1</span>)</span><br><span class="line"><span class="comment"># <generator object <genexpr> at 
0x7fc3aad2b2d0></span></span><br></pre></td></tr></table></figure><h2 id="iter-与-next"><a href="#iter-与-next" class="headerlink" title="iter 与 next"></a><code>iter</code> 与 <code>next</code></h2><p><code>iter</code> 将一个可迭代对象转换为一个迭代器,可以使用 <code>next</code> 进行遍历。遍历结束以后若继续迭代,则抛出 <code>StopIteration</code> 异常。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line">iterator = iter(list(range(<span class="number">3</span>)))</span><br><span class="line">print(iterator)</span><br><span class="line"><list_iterator object at <span class="number">0x7fc3aa756990</span>></span><br><span class="line">next(iterator)</span><br><span class="line"><span class="comment"># 0</span></span><br><span class="line">next(iterator)</span><br><span class="line"><span class="comment"># 1</span></span><br><span class="line">next(iterator)</span><br><span class="line"><span class="comment"># 2</span></span><br><span class="line">next(iterator)</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">StopIteration Traceback (most recent call last)</span><br><span class="line"><ipython-input<span class="number">-74</span><span class="number">-4</span>ce711c44abc> <span class="keyword">in</span> <module></span><br><span class="line">----> 1 next(iterator)</span><br><span class="line"></span><br><span class="line">StopIteration:</span><br></pre></td></tr></table></figure><h1 id="异常处理"><a href="#异常处理" class="headerlink" title="异常处理"></a>异常处理</h1><h2 id="Python-常用标准异常"><a href="#Python-常用标准异常" class="headerlink" title="Python 常用标准异常"></a>Python 常用标准异常</h2><ul><li>AssertionError:<code>Assertion</code> 的条件为 <code>False</code></li><li>AttributeError:尝试访问未知的对象属性</li><li>IndexError:索引超出序列的范围</li><li>KeyError:字典中查找一个不存在的关键字</li><li>NameError:尝试访问一个不存在的变量</li><li>TypeError:不同类型间的无效操作</li><li>ValueError:传入无效的参数<h2 id="try-except-语句"><a href="#try-except-语句" class="headerlink" title="try-except 语句"></a><code>try-except</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">try</span>:</span><br><span class="line"> expression</span><br><span class="line"><span class="keyword">except</span> Exception[ <span class="keyword">as</span> reason]:</span><br><span class="line"> action</span><br></pre></td></tr></table></figure>try 语句按照如下方式工作:</li><li>首先,执行 <code>try</code> 子句(在关键字 <code>try</code> 和关键字 <code>except</code> 之间的语句)</li><li>如果没有异常发生,忽略 <code>except</code> 子句,<code>try</code> 子句执行后结束。</li><li>如果在执行 <code>try</code> 子句的过程中发生了异常,那么 <code>try</code> 子句余下的部分将被忽略。如果异常的类型和 <code>except</code> 之后的名称相符,那么对应的 <code>except</code> 子句将被执行。最后执行<code>try - except</code> 语句之后的代码。</li><li>如果一个异常没有与任何的 <code>except</code> 匹配,那么这个异常将会传递给上层的 <code>try</code> 中。</li><li>一个 <code>try</code> 语句可能包含多个 
<code>except</code> 子句,分别来处理不同的特定的异常。最多只有一个分支会被执行。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">try</span>:</span><br><span class="line"> int(<span class="string">"abc"</span>)</span><br><span class="line"> s = <span class="number">1</span> + <span class="string">'1'</span></span><br><span class="line"> f = open(<span class="string">'test.txt'</span>)</span><br><span class="line"> print(f.read())</span><br><span class="line"> f.close()</span><br><span class="line"><span class="keyword">except</span> OSError <span class="keyword">as</span> error:</span><br><span class="line"> print(<span class="string">'打开文件出错\n原因是:'</span> + str(error))</span><br><span class="line"><span class="keyword">except</span> TypeError <span class="keyword">as</span> error:</span><br><span class="line"> print(<span class="string">'类型出错\n原因是:'</span> + str(error))</span><br><span class="line"><span class="keyword">except</span> ValueError <span class="keyword">as</span> error:</span><br><span class="line"> print(<span class="string">'数值出错\n原因是:'</span> + str(error))</span><br><span class="line"></span><br><span class="line"><span class="comment"># 数值出错</span></span><br><span class="line"><span class="comment"># 原因是:invalid literal for int() with base 10: 'abc'</span></span><br></pre></td></tr></table></figure><h2 id="try-except-finally-语句"><a href="#try-except-finally-语句" class="headerlink" title="try-except-finally 语句"></a><code>try-except-finally</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">try</span>:</span><br><span class="line"> 检测范围</span><br><span class="line"><span class="keyword">except</span> Exception[<span class="keyword">as</span> reason]:</span><br><span class="line"> 出现异常后的处理代码</span><br><span class="line"><span class="keyword">finally</span>:</span><br><span class="line"> 无论如何都会被执行的代码</span><br></pre></td></tr></table></figure>不管<code>try</code>子句里面有没有发生异常,<code>finally</code>子句都会执行。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span 
class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">divide</span><span class="params">(x, y)</span>:</span></span><br><span class="line"> <span class="keyword">try</span>:</span><br><span class="line"> result = x / y</span><br><span class="line"> print(<span class="string">"result is"</span>, result)</span><br><span class="line"> <span class="keyword">except</span> ZeroDivisionError:</span><br><span class="line"> print(<span class="string">"division by zero!"</span>)</span><br><span class="line"> <span class="keyword">finally</span>:</span><br><span class="line"> print(<span class="string">"executing finally clause"</span>)</span><br><span class="line"></span><br><span class="line"></span><br><span class="line">divide(<span class="number">2</span>, <span class="number">1</span>)</span><br><span class="line"><span class="comment"># result is 2.0</span></span><br><span class="line"><span class="comment"># executing finally clause</span></span><br><span class="line">divide(<span class="number">2</span>, <span class="number">0</span>)</span><br><span class="line"><span class="comment"># division by zero!</span></span><br><span class="line"><span class="comment"># executing finally clause</span></span><br><span class="line">divide(<span class="string">"2"</span>, <span class="string">"1"</span>)</span><br><span class="line"><span class="comment"># executing finally clause</span></span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">TypeError Traceback (most recent call last)</span><br><span class="line"><ipython-input<span class="number">-79</span><span class="number">-16805</span>cf48925> <span class="keyword">in</span> <module></span><br><span class="line"> <span class="number">15</span> <span class="comment"># division by zero!</span></span><br><span class="line"> <span class="number">16</span> <span class="comment"># executing finally clause</span></span><br><span class="line">---> 17 divide("2", "1")</span><br><span class="line"> <span class="number">18</span> <span class="comment"># executing finally clause</span></span><br><span class="line"> <span class="number">19</span> <span class="comment"># TypeError: unsupported operand type(s) for /: 'str' and 'str'</span></span><br><span class="line"></span><br><span class="line"><ipython-input<span class="number">-79</span><span class="number">-16805</span>cf48925> <span class="keyword">in</span> divide(x, y)</span><br><span class="line"> <span class="number">1</span> <span class="function"><span class="keyword">def</span> <span class="title">divide</span><span class="params">(x, y)</span>:</span></span><br><span class="line"> <span class="number">2</span> <span class="keyword">try</span>:</span><br><span class="line">----> 3 result = x / y</span><br><span class="line"> <span class="number">4</span> print(<span class="string">"result is"</span>, result)</span><br><span class="line"> <span class="number">5</span> <span class="keyword">except</span> ZeroDivisionError:</span><br><span 
class="line"></span><br><span class="line">TypeError: unsupported operand type(s) <span class="keyword">for</span> /: <span class="string">'str'</span> <span class="keyword">and</span> <span class="string">'str'</span></span><br></pre></td></tr></table></figure><h2 id="try-except-else-语句"><a href="#try-except-else-语句" class="headerlink" title="try-except-else 语句"></a><code>try-except-else</code> 语句</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">try</span>:</span><br><span class="line"> 检测范围</span><br><span class="line"><span class="keyword">except</span>:</span><br><span class="line"> 出现异常后的处理代码</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> 如果没有异常执行这块代码</span><br></pre></td></tr></table></figure>如果 <code>except</code>语句没有执行,则继续执行 <code>else</code> 语句;若执行则跳过 <code>else</code> 语句。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line">dict1 = {<span class="string">'a'</span>: <span class="number">1</span>, <span class="string">'b'</span>: <span class="number">2</span>, <span class="string">'v'</span>: <span class="number">22</span>}</span><br><span class="line"><span class="keyword">try</span>:</span><br><span class="line"> x = dict1[<span class="string">'y'</span>]</span><br><span class="line"><span class="keyword">except</span> KeyError:</span><br><span class="line"> print(<span class="string">'键错误'</span>)</span><br><span class="line"><span class="keyword">except</span> LookupError:</span><br><span class="line"> print(<span class="string">'查询错误'</span>)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(x)</span><br><span class="line"><span class="comment"># 键错误</span></span><br><span class="line">dict1 = {<span class="string">'a'</span>: <span class="number">1</span>, <span class="string">'b'</span>: <span class="number">2</span>, <span class="string">'v'</span>: <span class="number">22</span>}</span><br><span class="line"><span class="keyword">try</span>:</span><br><span class="line"> x = dict1[<span class="string">'a'</span>]</span><br><span class="line"><span class="keyword">except</span> KeyError:</span><br><span class="line"> print(<span class="string">'键错误'</span>)</span><br><span class="line"><span class="keyword">except</span> LookupError:</span><br><span class="line"> print(<span class="string">'查询错误'</span>)</span><br><span class="line"><span class="keyword">else</span>:</span><br><span class="line"> print(x)</span><br><span class="line"><span class="comment"># 1</span></span><br></pre></td></tr></table></figure><h2 
id="raise-语句"><a href="#raise-语句" class="headerlink" title="raise 语句"></a><code>raise</code> 语句</h2>抛出一个错误异常。<figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">raise NameError()</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">NameError Traceback (most recent call last)</span><br><span class="line"><ipython-input-80-1d90210dd9ab> in <module></span><br><span class="line">----> 1 raise NameError()</span><br><span class="line"></span><br><span class="line">NameError:</span><br></pre></td></tr></table></figure></li></ul>]]></content>
<tags>
<tag> 基础 </tag>
</tags>
</entry>
<entry>
<title>[工作站] 我的第一台个人深度学习工作站之配置环境篇</title>
<link href="2020/11/02/%E5%B7%A5%E4%BD%9C%E7%AB%99-%E6%88%91%E7%9A%84%E7%AC%AC%E4%B8%80%E5%8F%B0%E4%B8%AA%E4%BA%BA%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E4%B9%8B%E9%85%8D%E7%BD%AE%E7%8E%AF%E5%A2%83%E7%AF%87/"/>
<url>2020/11/02/%E5%B7%A5%E4%BD%9C%E7%AB%99-%E6%88%91%E7%9A%84%E7%AC%AC%E4%B8%80%E5%8F%B0%E4%B8%AA%E4%BA%BA%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E4%B9%8B%E9%85%8D%E7%BD%AE%E7%8E%AF%E5%A2%83%E7%AF%87/</url>
<content type="html"><![CDATA[<p>工作站的配置有两种方式:</p><ul><li>传统方法;</li><li>Docker。</li></ul><p>理论上 Docker 更好一点,因为 <code>nvidia-docker2</code> 是一个已经配置好的环境,不需要手动安装 CUDA 和 cuDNN,而且不需要了可以删除,更新也很方便。然而本文中将采用传统方法对工作站进行配置。使用 Docker 可以参考这篇<a href="https://cnvrg.io/how-to-setup-docker-and-nvidia-docker-2-0-on-ubuntu-18-04/" target="_blank" rel="noopener" title="How to setup Docker and Nvidia-Docker 2.0 on Ubuntu 18.04">文章</a>。</p><a id="more"></a><h2 id="安装系统"><a href="#安装系统" class="headerlink" title="安装系统"></a><a href="https://www.sysgeek.cn/install-ubuntu-20-04-lts-desktop/" target="_blank" rel="noopener" title="Ubuntu 20.04 LTS 桌面版详细安装指南">安装系统</a></h2><p>我选择的是 Ubuntu 20.04。我在 MacOS 上制作引导盘,方法在<a href="https://ubuntu.com/tutorials/create-a-usb-stick-on-macos#1-overview" target="_blank" rel="noopener" title="Create a bootable USB stick on macOS">这里</a>。一路安装都很简单,有几个建议:</p><ol><li>在选择安装模式的时候选择“Minimal Installation”,不安装更新和第三方软件,一会系统安装好了以后手动更新。<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkFUQhwosjXq2DdhTRka7burQzvdibLullKy8IIuv2m5KUx7icmaxjh5G4NDZOOTEK22NnzC6m3vIq5w/0?wx_fmt=png" alt="两个确认框都取消"></li><li>在分区时选择默认分区就好了,不需要 LVM,不要选择 ZFS。</li><li>在设置用户的时候选择 “Log in automatically”。</li></ol><h2 id="系统更新"><a href="#系统更新" class="headerlink" title="系统更新"></a>系统更新</h2><p>安装好了 Ubuntu 并设置好网络以后,我们要做的第一件事就是更新系统、安装基础组件。工作站配置好了以后就不要轻易更新了,以免破坏环境。</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo apt-get update && sudo apt-get upgrade && sudo apt-get autoremove</span><br><span class="line">>>> sudo apt install build-essential vim git curl wget make</span><br></pre></td></tr></table></figure><h2 id="设置-SSH"><a href="#设置-SSH" class="headerlink" title="设置 SSH"></a><a href="https://www.linuxbabe.com/linux-server/setup-passwordless-ssh-login" target="_blank" rel="noopener" title="2 Simple Steps to Set up Passwordless SSH Login on Ubuntu">设置 SSH</a></h2><ol><li>安装 SSH:<code>sudo apt install Openssh-server</code></li><li>启动 SSH:<code>sudo /etc/init.d/ssh start</code></li><li>开机自动启动 SSH:<code>sudo systemctl enable ssh</code></li><li>在客户端生成 SSH 秘钥<br><code>ssh-keygen</code> 或 强秘钥 <code>ssh-keygen -t rsa -b 4096</code></li><li>在客户端上传秘钥至服务器:<code>ssh-copy-id remote-user@server-ip</code></li><li>关闭 SSH 密码登录 <code>sudo vim /etc/ssh/sshd_config</code><br>找到 <code>#PasswordAuthentication yes</code> 修改为 PasswordAuthentication no<br>找到 <code>ChallengeResponseAuthentication</code> 修改为 <code>ChallengeResponseAuthentication no</code></li><li>重启 SSH 服务<br><code>sudo service ssh restart</code> 或 <code>sudo systemctl restart ssh</code></li><li>(可选)备份 SSH key<br><code>cp ~/.ssh/id_rsa* /path/to/safe/location/</code></li></ol><p>设置好 SSH 以后,我们可以在个人电脑上登录工作站。</p><h2 id="设置防火墙"><a href="#设置防火墙" class="headerlink" title="设置防火墙"></a><a href="https://www.digitalocean.com/community/tutorials/initial-server-setup-with-ubuntu-20-04" target="_blank" rel="noopener" title="Initial Server Setup with Ubuntu 20.04">设置防火墙</a></h2><p>下面来设置一下防火墙,防止有坏人来捣乱。</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo ufw allow OpenSSH</span><br><span class="line">>>> sudo ufw allow 22 <span class="comment"># 开放端口 22</span></span><br><span class="line">>>> ufw <span class="built_in">enable</span></span><br></pre></td></tr></table></figure><h2 id="安装显卡驱动"><a 
href="#安装显卡驱动" class="headerlink" title="安装显卡驱动"></a><a href="https://linuxconfig.org/how-to-install-the-nvidia-drivers-on-ubuntu-20-04-focal-fossa-linux" target="_blank" rel="noopener" title="How to install the NVIDIA drivers on Ubuntu 20.04 Focal Fossa Linux">安装显卡驱动</a></h2><p>Ubuntu 提供了一个很简单的命令来安装显卡驱动。</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo ubuntu-drivers autoinstall</span><br><span class="line">>>> sudo reboot <span class="comment"># 重启工作站</span></span><br></pre></td></tr></table></figure><h2 id="安装-cuda-11-1"><a href="#安装-cuda-11-1" class="headerlink" title="安装 cuda 11.1"></a><a href="https://zhuanlan.zhihu.com/p/78002221" target="_blank" rel="noopener" title="Ubuntu下CUDA,CUDNN和Tensorflow配置">安装 cuda 11.1</a></h2><p>这里安装最新的 CUDA Toolkit 11.1。记得安装的时候把”安装驱动程序“取消。</p><ol><li>安装 cuda<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">>>> wget https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda_11.1.1_455.32.00_linux.run</span><br><span class="line">>>> sudo sh cuda_11.1.1_455.32.00_linux.run</span><br></pre></td></tr></table></figure></li><li>将以下内容加进 /etc/profile:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="built_in">export</span> PATH=/usr/<span class="built_in">local</span>/cuda-11.1/bin:<span class="variable">$PATH</span></span><br><span class="line"><span class="built_in">export</span> LD_LIBRARY_PATH=/usr/<span class="built_in">local</span>/cuda-11.1/lib64<span class="variable">$LD_LIBRARY_PATH</span></span><br></pre></td></tr></table></figure></li><li>重启电脑 <code>sudo reboot</code></li><li>检查cuda 版本<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">>>> nvcc -V</span><br><span class="line">nvcc: NVIDIA (R) Cuda compiler driver</span><br><span class="line">Copyright (c) 2005-2020 NVIDIA Corporation</span><br><span class="line">Built on Mon_Oct_12_20:09:46_PDT_2020</span><br><span class="line">Cuda compilation tools, release 11.1, V11.1.105</span><br><span class="line">Build cuda_11.1.TC455_06.29190527_0</span><br></pre></td></tr></table></figure></li><li>测试<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">>>> cuda-install-samples-11.1.sh ~</span><br><span class="line">Copying samples to /home/nlp/NVIDIA_CUDA-11.1_Samples now...</span><br><span class="line">Finished copying samples.</span><br><span class="line">>>> <span class="built_in">cd</span> 
/home/nlp/NVIDIA_CUDA-11.1_Samples/</span><br><span class="line">>>> make <span class="comment"># 耗时约 20 分钟</span></span><br><span class="line">>>> ./1_Utilities/deviceQuery/deviceQuery <span class="comment"># 如果通过测试会显示如下信息</span></span><br><span class="line">./1_Utilities/deviceQuery/deviceQuery Starting...</span><br><span class="line"></span><br><span class="line"> CUDA Device Query (Runtime API) version (CUDART static linking)</span><br><span class="line"></span><br><span class="line">Detected 2 CUDA Capable device(s)</span><br><span class="line">...</span><br><span class="line">deviceQuery, CUDA Driver = CUDART, CUDA Driver Version = 11.0, CUDA Runtime Version = 11.1, NumDevs = 2</span><br><span class="line">Result = PASS</span><br></pre></td></tr></table></figure></li></ol><h2 id="安装-cuDNN-8-0-4"><a href="#安装-cuDNN-8-0-4" class="headerlink" title="安装 cuDNN 8.0.4"></a><a href="https://zhuanlan.zhihu.com/p/143429249" target="_blank" rel="noopener" title="简易记录:安装GPU驱动,CUDA和cuDNN">安装 cuDNN 8.0.4</a></h2><p>安装以前要先<a href="https://developer.nvidia.com/cudnn-download-survey" target="_blank" rel="noopener" title="NVIDIA Developer Program Membership Required">注册新用户</a>。</p><ol><li>在注册后,到 <code>https://developer.nvidia.com/rdp/cudnn-download</code> 下载 <code>libcudnn8_8.0.4.30-1+cuda11.1_amd64.deb</code>,<code>libcudnn8-dev_8.0.4.30-1+cuda11.1_amd64.deb</code> 和 <code>libcudnn8-samples_8.0.4.30-1+cuda11.1_amd64.deb</code>。可以在个人电脑上下载然后 <code>scp</code> 给工作站,下同。</li><li>解包:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo dpkg -i libcudnn*</span><br></pre></td></tr></table></figure></li><li>检查 cudnn:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">>>> cp -r /usr/src/cudnn_samples_v8/ <span class="variable">$HOME</span></span><br><span class="line">>>> <span class="built_in">cd</span> <span class="variable">$HOME</span>/cudnn_samples_v8/mnistCUDNN</span><br><span class="line">>>> make clean && make</span><br><span class="line">...</span><br><span class="line">>>> ./mnistCUDNN</span><br><span class="line">...</span><br><span class="line">Test passed!</span><br></pre></td></tr></table></figure><h2 id="安装-Miniconda"><a href="#安装-Miniconda" class="headerlink" title="安装 Miniconda"></a>安装 Miniconda</h2>没必要安装 Anaconda(用不到 GUI),安装 Miniconda 就可以了。</li><li>到 <code>https://docs.conda.io/en/latest/miniconda.html#linux-installers</code> 下载 Miniconda3 Linux 64-bit。运行<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">>>> bash Miniconda3-latest-Linux-x86_64.sh</span><br><span class="line">>>> <span class="built_in">source</span> ~/.bashrc 更新 bash 环境</span><br><span class="line">>>> conda update conda</span><br></pre></td></tr></table></figure></li><li>安装虚拟环境<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> conda create --name nlp python=3.8</span><br></pre></td></tr></table></figure></li></ol><h2 id="配置远程-Jupyter-lab"><a href="#配置远程-Jupyter-lab" class="headerlink" 
title="配置远程 Jupyter lab"></a><a href="https://blog.csdn.net/starfish55555/article/details/96788672" target="_blank" rel="noopener" title="Ubuntu与windows下配置安装jupyter-notebook以及其开机自启、后台运行与远程访问">配置远程 Jupyter lab</a><a href="https://blog.csdn.net/tuzixini/article/details/79105482" target="_blank" rel="noopener" title="服务器 配置 Jupyter notebook 远程访问 (Ubuntu 14.04)"></a><a href="https://zhuanlan.zhihu.com/p/75524289" target="_blank" rel="noopener" title="Jupyter开发环境的进阶配置(Ubuntu 18.04)"></a></h2><p>安装 Jupyter lab 后可以选择 Jupyter lab 还是 Jupyter notebook,方法是在登录的域名后面加 “/lab?” 或者 “/tree?”。</p><ol><li>生成配置文件:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">>>> pip install jupyterlab nodejs</span><br><span class="line">>>> jupyter lab --generate-config <span class="comment"># 生成配置文件</span></span><br><span class="line">Writing default config to: /home/nlp/.jupyter/jupyter_notebook_config.py</span><br></pre></td></tr></table></figure></li><li>设置登录密码:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 首先进入 python 命令行</span></span><br><span class="line">>>> python3 </span><br><span class="line"><span class="comment"># 在命令行下输入</span></span><br><span class="line">>>> from notebook.auth import passwd; passwd()</span><br><span class="line"><span class="comment"># 按照提示输入密码,这是 jupyter 的登陆密码</span></span><br></pre></td></tr></table></figure></li><li>设置成功会出现形如下面的哈希(hash)密码, 保存好,下面会用到<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="string">'argon2:$argon2id$v=19$m=10240,t=10,p=8$mYbUFvU1Csiwz3UGlsRwEA$q7r2mSN5RbFwjbhZCew4fg'</span></span><br></pre></td></tr></table></figure></li><li>配置 Jupyter lab:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo vim /home/nlp/.jupyter/jupyter_notebook_config.py <span class="comment"># 配置 Jupyter lab</span></span><br><span class="line">c.NotebookApp.ip = <span class="string">'*'</span></span><br><span class="line">c.NotebookApp.token = <span class="string">''</span></span><br><span class="line">c.NotebookApp.password = <span class="string">'argon2:$argon2id$v=19$m=10240,t=10,p=8$mYbUFvU1Csiwz3UGlsRwEA$q7r2mSN5RbFwjbhZCew4fg'</span></span><br><span class="line">c.NotebookApp.open_browser = False</span><br><span class="line">c.NotebookApp.notebook_dir = <span class="string">'/home/nlp/Documents'</span> <span class="comment"># 设置默认根目录</span></span><br><span class="line">c.NotebookApp.allow_remote_access = True</span><br><span class="line">c.NotebookApp.port = 8889 <span class="comment"># 设置端口</span></span><br></pre></td></tr></table></figure></li><li>防火墙开放端口<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo ufw allow 
8889</span><br></pre></td></tr></table></figure></li><li>注册虚拟环境<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">>>> conda activate nlp</span><br><span class="line">>>> conda install ipykernel notebook</span><br><span class="line">>>> python -m ipykernel install --user --name nlp --display-name <span class="string">"NLP"</span></span><br></pre></td></tr></table></figure>设置以后在 SSH 状态中输入 “jupyter lab”,再在浏览器地址栏里输入 “域名:端口号”(本文中端口为 8889)就可以打开 Jupyter lab 了。 因为工作站很少关机,在 ssh 以后输入 <code>nohup jupyter lab &</code> 就可以让 Jupyter lab 保持在后台运行。</li></ol><h2 id="设置-Jupyter-notebook-开机自动在后台启动(可选)"><a href="#设置-Jupyter-notebook-开机自动在后台启动(可选)" class="headerlink" title="设置 Jupyter notebook 开机自动在后台启动(可选)"></a><a href="https://medium.com/@datamove/setup-jupyter-notebook-server-to-start-up-on-boot-and-use-different-conda-environments-147b091b9a5f" target="_blank" rel="noopener" title="Setup Jupyter Notebook Server to Start Up on Boot and Use Different Conda Environments.">设置 Jupyter notebook 开机自动在后台启动</a>(可选)</h2><p>这里要注意:开机自动启动的是 Jupyter notebook,设定好以后就不能启动 Jupyter lab 了。我更喜欢 Jupyter lab,所以调试好以后又删掉了。反正工作站不关机,把 Jupyter lab 挂在后台也不麻烦。</p><ol><li>在 terminal 中输入<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo vim /lib/systemd/system/ipython-notebook.service</span><br></pre></td></tr></table></figure></li><li>在编辑器里粘贴如下:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line">[Unit]</span><br><span class="line"> Description=IPython notebook</span><br><span class="line">[Service]</span><br><span class="line"> Type=simple</span><br><span class="line"> PIDFile=/var/run/ipython-notebook.pid</span><br><span class="line"> <span class="comment"># 环境是 Jupyter 的默认环境,可以在编辑器内更改</span></span><br><span class="line">Environment=<span class="string">"PATH=/home/*用户名*/miniconda3/envs/*kernel 环境名*/bin:/home/*用户名*/miniconda3/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"</span></span><br><span class="line"> ExecStart=/home/*用户名*/miniconda3/envs/*kernel 环境名*/bin/jupyter-notebook --no-browser --notebook-dir=/home/nlp/Documents --NotebookApp.token=*token* --ip=0.0.0.0</span><br><span class="line"> User=*用户名*</span><br><span class="line"> Group=*用户组名*</span><br><span class="line"> WorkingDirectory=/home/*用户名* <span class="comment"># 此处可根据需要自由设定</span></span><br><span class="line">[Install]</span><br><span class="line"> WantedBy=multi-user.target</span><br></pre></td></tr></table></figure></li><li>依次输入<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">>>> sudo systemctl daemon-reload</span><br><span class="line">>>> sudo systemctl <span class="built_in">enable</span> ipython-notebook</span><br><span class="line">Created 
symlink /etc/systemd/system/multi-user.target.wants/jupyter.service → /lib/systemd/system/jupyter.service.</span><br><span class="line">>>> sudo systemctl start ipython-notebook</span><br></pre></td></tr></table></figure></li><li>验证 Jupyter notebook 是否设置成功,输入<code>sudo systemctl status ipython-notebook</code>,成功的话会输出如下类似的信息:<figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line">● ipython-notebook.service - IPython notebook</span><br><span class="line"> Loaded: loaded (/lib/systemd/system/ipython-notebook.service; enabled; vendor preset: enabled)</span><br><span class="line"> Active: active (running) since Mon 2020-11-02 20:53:51 EST; 7s ago</span><br><span class="line"> Main PID: 3838 (jupyter-noteboo)</span><br><span class="line"> Tasks: 1 (<span class="built_in">limit</span>: 19026)</span><br><span class="line"> Memory: 60.1M</span><br><span class="line"> CGroup: /system.slice/ipython-notebook.service</span><br><span class="line"> └─3838 /home/nlp/miniconda3/envs/notebook_env/bin/python /home/nlp/miniconda3/envs/notebook_env/bin/jupyter-notebook --no-browser --notebook-dir=/home/nlp --NotebookApp.token=argon2:<span class="variable">$argon2id</span><span class="variable">$v</span>=1></span><br><span class="line"></span><br><span class="line">Nov 02 20:53:51 WORKSTATION systemd[1]: Started IPython notebook.</span><br><span class="line">Nov 02 20:53:51 WORKSTATION jupyter-notebook[3838]: [I 20:53:51.858 NotebookApp] [nb_conda_kernels] enabled, 3 kernels found</span><br><span class="line">Nov 02 20:53:52 WORKSTATION jupyter-notebook[3838]: [I 20:53:52.036 NotebookApp] Serving notebooks from <span class="built_in">local</span> directory: /home/nlp</span><br><span class="line">Nov 02 20:53:52 WORKSTATION jupyter-notebook[3838]: [I 20:53:52.037 NotebookApp] Jupyter Notebook 6.1.4 is running at:</span><br><span class="line">Nov 02 20:53:52 WORKSTATION jupyter-notebook[3838]: [I 20:53:52.037 NotebookApp] http://WORKSTATION:8889/?token=...</span><br><span class="line">Nov 02 20:53:52 WORKSTATION jupyter-notebook[3838]: [I 20:53:52.037 NotebookApp] or http://127.0.0.1:8889/?token=...</span><br><span class="line">Nov 02 20:53:52 WORKSTATION jupyter-notebook[3838]: [I 20:53:52.037 NotebookApp] Use Control-C to stop this server and shut down all kernels (twice to skip confirmation).</span><br></pre></td></tr></table></figure></li></ol><h2 id="安装-Python-包、配置-Jupyter、Vim、IDE"><a href="#安装-Python-包、配置-Jupyter、Vim、IDE" class="headerlink" title="安装 Python 包、配置 Jupyter、Vim、IDE"></a>安装 Python 包、配置 Jupyter、Vim、IDE</h2><p>现在工作站基本上配置好啦!之后安装各种包,配置 Jupyter、Vim 和各种 IDE 就看各位的喜好啦。我用的 IDE 是 VS code(不想搞破解版 PyCharm),安装 Remote - SSH 就可以远程炼丹啦。</p><h2 id="Bonus-Benchmark"><a href="#Bonus-Benchmark" class="headerlink" title="Bonus: Benchmark"></a>Bonus: Benchmark</h2><p>工作站配置好了,来简单做一下 benchmark。作为对照的是 Google Colab,Colab 上的 GPU 是随机分配的,这次分配到的是 Tesla T4。脚本<a 
href="https://github.com/vincent507cpu/nlp_projects/blob/main/text%20classification/03%20transformers%20Colab%20GPU.ipynb" target="_blank" rel="noopener" title="Colab GPU">在此</a>。每个 epoch 用时大约 40 分钟。</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| NVIDIA-SMI 455.32.00 Driver Version: 418.67 CUDA Version: 10.1 |</span><br><span class="line">|-------------------------------+----------------------+----------------------+</span><br><span class="line">| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |</span><br><span class="line">| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |</span><br><span class="line">| | | MIG M. |</span><br><span class="line">|===============================+======================+======================|</span><br><span class="line">| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |</span><br><span class="line">| N/A 45C P8 11W / 70W | 0MiB / 15079MiB | 0% Default |</span><br><span class="line">| | | ERR! |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br></pre></td></tr></table></figure><p>然后分别使用一张 2060 和两张 2060 显卡在工作站上测试,脚本<a href="https://github.com/vincent507cpu/nlp_projects/blob/main/text%20classification/05%20transformers%20single%20GPU.ipynb" target="_blank" rel="noopener" title="single GPU">在此</a>。batch size 设置为 64 的话会爆显存,所以设置成了 32。单卡在运行时的显卡状态为</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| NVIDIA-SMI 450.80.02 Driver Version: 450.80.02 CUDA Version: 11.0 |</span><br><span class="line">|-------------------------------+----------------------+----------------------+</span><br><span class="line">| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |</span><br><span class="line">| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |</span><br><span class="line">| | | MIG M. 
|</span><br><span class="line">|===============================+======================+======================|</span><br><span class="line">| 0 GeForce RTX 2060 Off | 00000000:08:00.0 On | N/A |</span><br><span class="line">| 90% 83C P2 163W / 170W | 5077MiB / 5926MiB | 96% Default |</span><br><span class="line">| | | N/A |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br><span class="line">| 1 GeForce RTX 2060 Off | 00000000:09:00.0 Off | N/A |</span><br><span class="line">| 46% 44C P8 7W / 170W | 9MiB / 5934MiB | 0% Default |</span><br><span class="line">| | | N/A |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br></pre></td></tr></table></figure><p>每个 epoch 用时大约 27 分钟。双卡在运行时的状态为</p><figure class="highlight bash"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| NVIDIA-SMI 450.80.02 Driver Version: 450.80.02 CUDA Version: 11.0 |</span><br><span class="line">|-------------------------------+----------------------+----------------------+</span><br><span class="line">| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |</span><br><span class="line">| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |</span><br><span class="line">| | | MIG M. |</span><br><span class="line">|===============================+======================+======================|</span><br><span class="line">| 0 GeForce RTX 2060 Off | 00000000:08:00.0 On | N/A |</span><br><span class="line">| 92% 84C P2 162W / 170W | 5702MiB / 5926MiB | 93% Default |</span><br><span class="line">| | | N/A |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br><span class="line">| 1 GeForce RTX 2060 Off | 00000000:09:00.0 Off | N/A |</span><br><span class="line">| 88% 82C P2 161W / 170W | 3148MiB / 5934MiB | 34% Default |</span><br><span class="line">| | | N/A |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br></pre></td></tr></table></figure><p>每个 epoch 用时大约 17 分钟,脚本<a href="https://github.com/vincent507cpu/nlp_projects/blob/main/text%20classification/06%20transformers%20multiple%20GPU.ipynb" target="_blank" rel="noopener" title="multiple GPU">在此</a>。可以看到两张显卡的显存占用不同(5702MiB vs 3148MiB),这种不平衡已经超过本文的讨论范围了。</p>
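<p>不过可以简单看一下原因。下面是一个最小化的示意(假设使用 PyTorch;模型与输入均为占位,并非上面脚本的原始实现):<code>nn.DataParallel</code> 在前向时把 batch 均分到各张卡,但输出会汇聚回 0 号卡,梯度也在 0 号卡上归并更新,所以 0 号卡的显存占用更高。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torch.nn <span class="keyword">as</span> nn</span><br><span class="line"></span><br><span class="line">print(torch.cuda.device_count())  <span class="comment"># 本文的工作站上输出 2</span></span><br><span class="line">model = nn.Linear(<span class="number">1024</span>, <span class="number">2</span>)  <span class="comment"># 占位模型,实际应为待训练的网络</span></span><br><span class="line"><span class="keyword">if</span> torch.cuda.device_count() > <span class="number">1</span>:</span><br><span class="line">    model = nn.DataParallel(model)  <span class="comment"># 把模型复制到所有可见显卡</span></span><br><span class="line">model = model.cuda()  <span class="comment"># 参数先放到 0 号卡,前向时再广播到其它卡</span></span><br><span class="line">x = torch.randn(<span class="number">32</span>, <span class="number">1024</span>).cuda()  <span class="comment"># 占位输入,batch 会被均分为 16 + 16</span></span><br><span class="line">print(model(x).shape)  <span class="comment"># torch.Size([32, 2]),输出汇聚回 0 号卡</span></span><br></pre></td></tr></table></figure><hr><p>可以看到,Colab 唯一的优势在于显存比较大,算力被入门级的 2060 完爆。我的这台工作站待机时两个 GPU 功耗为 17W,24 小时待机也没有负担。根据传闻,3060 Ti 的性能与 2080 Super 相当,那么根据 Tom’s Hardware 的 <a href="https://www.tomshardware.com/reviews/gpu-hierarchy,4388.html" target="_blank" rel="noopener" title="GPU Benchmarks and Hierarchy: Graphics Cards Ranked">GPU 性能排行榜</a>,3060 Ti 相比 2060 预计有 50% 的性能提升,使用一张 3060 Ti 时每个 epoch 用时会在 20 分钟左右,相比 Colab 时间减少了一半。一个字,香!</p>]]></content>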
<tags>
<tag> 深度学习工作站 </tag>
</tags>
</entry>
<entry>
<title>[工作站] 我的第一台个人深度学习工作站之硬件篇</title>
<link href="2020/11/01/%E5%B7%A5%E4%BD%9C%E7%AB%99-%E6%88%91%E7%9A%84%E7%AC%AC%E4%B8%80%E5%8F%B0%E4%B8%AA%E4%BA%BA%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E4%B9%8B%E7%A1%AC%E4%BB%B6%E7%AF%87/"/>
<url>2020/11/01/%E5%B7%A5%E4%BD%9C%E7%AB%99-%E6%88%91%E7%9A%84%E7%AC%AC%E4%B8%80%E5%8F%B0%E4%B8%AA%E4%BA%BA%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E4%B9%8B%E7%A1%AC%E4%BB%B6%E7%AF%87/</url>
<content type="html"><![CDATA[<h1 id="工作站配置"><a href="#工作站配置" class="headerlink" title="工作站配置"></a>工作站配置</h1><p>我的计划是使用入门级显卡配一台双 GPU 工作站使用两年,预留两年之内升级一次的空间,在保证性能和可扩展性的前提下,够用就好。经过若干次修改,最终的配置为:</p><a id="more"></a><ul><li>CPU:Ryzen 3 3100</li></ul><p>可能很多人觉得现在都 2020 年了,还用 4 核 CPU 太落伍。其实我想买 3600,但是没货,不得已买了 3100(不想多花钱买 3600X)。在深度学习中 4 核 CPU 搭配 2 块 GPU 已经足够了。不过在考虑明年升级成 5600。</p><ul><li>散热器:九州风神 GAMMAXX 200T</li></ul><p>虽然 CPU 的原装散热器已经可以满足需要,我买它的原因是有折扣,就算为明年换 5600 提前买好了。</p><ul><li>主板:Asus Prime X570-Pro</li></ul><p>随便一块 X570 主板基本都有两条 PCI-E 插槽(ITX 小主板除外),不过通常是 16-4 通道的配置。虽然 PCI-E 4.0 下 x4 应该不是瓶颈,用起来心里还是有点不舒服。所以选择了支持 SLI 的主板中最便宜的一块。</p><ul><li>内存:OLOy DDR4 RAM 16GB (2x8GB) 3200 MHz </li></ul><p>OLOy 是一个新牌子,买它完全是因为便宜。<a href="https://www.tomshardware.com/features/oloy-ram-should-you-buy" target="_blank" rel="noopener" title="Should You Buy OLOy RAM?">Tom Hardware 认为它的产品质量还可以</a>,先买来试试。</p><ul><li>显卡:两块 EVGA 06G-P4-2066-KR GeForce RTX 2060 KO Gaming</li></ul><p>我的目标是 3060 Ti,但是还没有发布,所以先用两块 2060 当亮机卡。</p><ul><li>硬盘:Crucial P1 500GB</li></ul><p>这块硬盘的性能很一般,买它因为折扣很大(超过一半),而且再不济也是一块 NVMe 硬盘,性能与顶级 PCI-E 3.0 NVMe 硬盘最多差 10%,完全够用。</p><p>还有一块差不多 10 年前买的 3T 硬盘,里面有很多以前的文件,现在读不出来了,把外壳拿掉直接插电脑上看看能不能读出来。</p><ul><li>电源:Thermaltake ToughPower 750W 80 Plus Gold Semi Modular Power Supply</li></ul><p>买它的原因也是有折扣,而且 750W 带 5600 + 两块 3060 Ti 应该也没问题。</p><p>显示器、键盘、鼠标用现成的,机箱买的最便宜的,另外买了一个 150M 的 WiFi 接收器。这就是全部配件了:<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicbxcnd3rovL3BCribuely2RckeeQzvoibo2xCMySB5rcKUzvv8HCeE2UQ/0?wx_fmt=jpeg" alt=""></p><p>关于配件的选择请参考:</p><ul><li><a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485308&idx=1&sn=fc687999f159fed9614688e55d948c58&chksm=cf06b640f8713f56ff63c20acebd8c846a517be81a8fdbfae21062bc96819046630238dd0381&token=1239929154&lang=zh_CN#rd" target="_blank" rel="noopener">《2020 年 10 月的多 GPU 深度学习工作站配置指南》</a></li><li><a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485265&idx=1&sn=9da954f0dd5c00884a6b3b92d7f2349a&chksm=cf06b66df8713f7b7ff0200a8e3b7ab324f3e6a72a0355cd5f9d0479cf802e4bb5df29d21ebc&token=1239929154&lang=zh_CN#rd" target="_blank" rel="noopener">《2020 年 10 月的单 GPU 深度学习工作站配置指南》</a><h1 id="升级计划"><a href="#升级计划" class="headerlink" title="升级计划"></a>升级计划</h1></li></ul><h2 id="第一次升级:2021-年上半年"><a href="#第一次升级:2021-年上半年" class="headerlink" title="第一次升级:2021 年上半年"></a>第一次升级:2021 年上半年</h2><p>把 CPU 升级到 5600,如果 3100 用的还行,就不升级了。</p><p>传说中 NVIDIA 计划在 12 月份发布 16GB 显存的 3070 Super,不过现在有消息说计划已经取消。那就把显卡换成两块 8GB 显存的 30 系列的最低款(目前可能是 3060 Ti)。内存升级到 32GB。</p><h2 id="第二次升级:2022-年下半年"><a href="#第二次升级:2022-年下半年" class="headerlink" title="第二次升级:2022 年下半年"></a>第二次升级:2022 年下半年</h2><p>如果两年以后进步很大,这台电脑不能满足需要了,就再升级一次。这应该是一台全新的工作站,目标是高性能,所有配件都要换,应该可以使用五年。</p><ul><li>CPU:Ryzen 9 5950 的下一代</li><li>内存:64 ~ 128 GB</li><li>显卡:2 块带水冷的 3090 显卡</li><li>硬盘:1 块 2T SSD 硬盘</li><li>电源:1200W 电源</li><li>机箱:具有风道设计的机箱</li></ul><h1 id="装机过程"><a href="#装机过程" class="headerlink" title="装机过程"></a>装机过程</h1><p>首先要了解需要连接哪些线,可以先看说明书。这台工作站没有任何灯效,所以少了一点麻烦。<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZriciczCE9DMrTFk0RX1LTnvRy8LEFoyLvhbibaSRG7J4zaia8JicTt9Htcbjqg/0?wx_fmt=png" alt=""><br>主板上需要连接的线有:</p><ul><li>24 pin 主板供电(图中右侧 1)</li><li>8 pin CPU 供电(图中上侧 1)</li><li>4 pin CPU 风扇供电(图中上侧 3)</li><li>4 pin 机箱风扇供电(图中中部左侧 3)</li><li>机箱前面板 USB 3 接口(图中中部右侧 9)</li><li>机箱控制面板排针(图中下方右侧 14)</li><li>机箱前面板 USB 2.0 接口(图中下方中部 18)</li><li>机箱音频连接排针(虽然我不用工作站的声音,还是接上了。图中下方左侧 21)</li></ul><p>还有两张显卡需要两条 8 pin 供电线,与电源相连。这些插口都有防呆设计,插反了是插不进去的。</p><h2 id="CPU"><a 
href="#CPU" class="headerlink" title="CPU"></a>CPU</h2><p>先把拨杆拉开。主板的 CPU 插槽和 CPU 的一角上都有一个三角标记,对准了把 CPU 放下去,拉下拨杆。拉下拨杆的时候有一点阻力,稍微用力一点就可以了。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicjib6bJRUiajNISkZQTwvsNN3SRRic6cRmibcZusynv4HqqSIcd7fQEyNAg/0?wx_fmt=jpeg" alt=""><br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicEIqyHQOzKkNib5fBfq4CBOccKQ5Lh0ulZNqiceUBc2YORr7ibx82pkOwQ/0?wx_fmt=jpeg" alt=""></p><h2 id="内存"><a href="#内存" class="headerlink" title="内存"></a>内存</h2><p>现在的电脑都是双通道内存设计,假如有 4 个内存插槽需要插两根内存,要以 1-3 或 2-4 这样插。扳开卡扣,看准了方向把内存插紧,听到“咔”的一声就插好了。插反了是插不进去的。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicSyXS2WrgXD807y9zSIeEoBMAzzuJ4WZjQic8hIQwHvNkh8o2l7kfHWQ/0?wx_fmt=jpeg" alt=""></p><h2 id="SSD-硬盘"><a href="#SSD-硬盘" class="headerlink" title="SSD 硬盘"></a>SSD 硬盘</h2><p>有的主板上有一些 SSD 插槽有扇热片,需要先把扇热片取下来。SSD 硬盘一般是 2280 尺寸,对应离插槽第 3 远的螺丝。先要拧一颗加高螺丝(用手就可以拧),然后将 SSD 硬盘插进去,拧上固定螺丝。扇热片上一般有散热胶贴纸,先把贴纸撕下来,再把扇热片放回去。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricic2DP0jVKBE73ib6uv5KK1aRJVQKC5E1okzOiaZjUia23E1KdcfmR2NWoyQ/0?wx_fmt=jpeg" alt=""></p><h2 id="CPU-散热器"><a href="#CPU-散热器" class="headerlink" title="CPU 散热器"></a>CPU 散热器</h2><p>CPU 插槽的两边有两个卡扣,将 CPU 散热器卡在上面就可以了,需要用点力气。需要插第一根供电线了。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicC8iaBhatAibTgWvib0PcHHoWABWtWBrZWcmzvcjLnjc9Sib0G52cdSwHqg/0?wx_fmt=jpeg" alt=""></p><h2 id="把电源放进机箱里并固定"><a href="#把电源放进机箱里并固定" class="headerlink" title="把电源放进机箱里并固定"></a>把电源放进机箱里并固定</h2><p>在把电源放进机箱前可以把要用的线先接上。电源分全模组、半模组和非模组三种。非模组电源是所有线事先连接在电源上,全模组电源是所有线都根据需要连接,半模组电源介于二者之间。我的电源是半模组电源,需要再连接两条 8 pin 供电线。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZriciciabLYYOW9EGv7l3zLW1dzSHWHKXKYaliciaOzqKo7zDnLeD5UddsvKbtA/0?wx_fmt=jpeg" alt=""><br>现在可以把电源放进机箱了。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicNu3n87EzBWzBEv7VYib4POIZ6ia45lYErb0qB7aISAW8qGfg1JvYl6hg/0?wx_fmt=jpeg" alt=""></p><h2 id="把主板放进机箱并固定"><a href="#把主板放进机箱并固定" class="headerlink" title="把主板放进机箱并固定"></a>把主板放进机箱并固定</h2><p>把主板放进机箱前需要先看看那些螺丝孔需要基座螺丝,需要的话要先拧上。我的这张主板一共需要固定 9 颗螺丝,上中下各 3 颗。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricic3sU0N2zia4l0AIv1kWonDdwFfU1Y1HBSoE0owQxkI0ZUbibnK3gAqpMQ/0?wx_fmt=jpeg" alt=""></p><h2 id="连接所有信号线和供电线"><a href="#连接所有信号线和供电线" class="headerlink" title="连接所有信号线和供电线"></a>连接所有信号线和供电线</h2><p>这应该是整个过程里最麻烦的一步了,然而说明书和接头上都有标记,耐心一点一根根插上就行了。</p><h2 id="插上显卡并连接供电"><a href="#插上显卡并连接供电" class="headerlink" title="插上显卡并连接供电"></a>插上显卡并连接供电</h2><p>首先把机箱上对应位置的挡板拿下去,把 PCI-E 插槽右侧的卡扣扳下去,把显卡插上听见“咔”的一声就行了。然后在显卡上插上对应数量的供电线。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicfrAbfbmeSpK2JDRogxttZ5za1Vy8Q7hUVCBFKR9WP8YvGHzOadP8gQ/0?wx_fmt=jpeg" alt=""></p><h2 id="理线"><a href="#理线" class="headerlink" title="理线"></a>理线</h2><p>这一步不是必须的,但是规整的内部空间有利于风道的畅通。一般电源和机箱会送几根 zip tie 或者尼龙扣,用它们把线捆在一起就可以了。正面的理线在上图里有,下面是背部的理线:<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicTTbassaXdRFvoyVYaAKFlopmhdCsOClhPa1rYiaiaO0oo6UIKke8WcLA/0?wx_fmt=jpeg" alt=""><br>至此电脑就组装完了,很简单吧?在合上机箱盖之前可以先把电脑点亮测试一下:<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHZxA0uGf9GCaiaFIDSmKk6y9mWic9JiawV9QSmK2PaLzFI6dictjicdal9O7I6VktremGYt3pykpAgqsg/0?wx_fmt=jpeg" alt=""><br>这绿油油的光好像魔兽世界西瘟疫之地上的绿光。如果连接到显示器上,应该可以看到开机自检(这里显示的 BIOS):<br><img 
src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkGDIRyOlnx8uLyibZtHgZricicicAt8umrJTay2EXoib3Ficf2cMMvBehsLGKNsRr6icSKbTxib78pLs7Kxdw/0?wx_fmt=jpeg" alt=""><br>至此工作站就组装完毕了,下一篇来谈谈环境设置。</p>]]></content>
<tags>
<tag> 深度学习工作站 </tag>
</tags>
</entry>
<entry>
<title>[工作站] 2020 年 10 月的多 GPU 深度学习工作站配置指南</title>
<link href="2020/10/10/%E5%B7%A5%E4%BD%9C%E7%AB%99-2020-%E5%B9%B4-10-%E6%9C%88%E7%9A%84%E5%A4%9A-GPU-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E9%85%8D%E7%BD%AE%E6%8C%87%E5%8D%97/"/>
<url>2020/10/10/%E5%B7%A5%E4%BD%9C%E7%AB%99-2020-%E5%B9%B4-10-%E6%9C%88%E7%9A%84%E5%A4%9A-GPU-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E9%85%8D%E7%BD%AE%E6%8C%87%E5%8D%97/</url>
<content type="html"><![CDATA[<p>本文接上一篇<a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485265&idx=1&sn=9da954f0dd5c00884a6b3b92d7f2349a&chksm=cf06b66df8713f7b7ff0200a8e3b7ab324f3e6a72a0355cd5f9d0479cf802e4bb5df29d21ebc&token=1032059941&lang=zh_CN#rd" target="_blank" rel="noopener">《2020 年 10 月的单 GPU 深度学习工作站配置指南》</a>,探讨多 GPU 工作站的搭建。很多在单 GPU 工作站中不甚重要的因素在多 GPU 工作站中变得举足轻重。</p><a id="more"></a><p>本文主要参考了以下文章:</p><ul><li><a href="https://timdettmers.com/2020/09/07/which-gpu-for-deep-learning/" target="_blank" rel="noopener" title="Which GPU(s) to Get for Deep Learning: My Experience and Advice for Using GPUs in Deep Learning">Which GPU(s) to Get for Deep Learning: My Experience and Advice for Using GPUs in Deep Learning</a></li><li><a href="https://timdettmers.com/2018/12/16/deep-learning-hardware-guide/" target="_blank" rel="noopener" title="A Full Hardware Guide to Deep Learning">A Full Hardware Guide to Deep Learning</a></li><li><a href="https://lambdalabs.com/blog/deep-learning-hardware-deep-dive-rtx-30xx/#blower-gpus" target="_blank" rel="noopener" title="Deep Learning Hardware Deep Dive – RTX 3090, RTX 3080, and RTX 3070">Deep Learning Hardware Deep Dive – RTX 3090, RTX 3080, and RTX 3070</a></li><li><a href="https://www.howtogeek.com/365215/what’s-the-difference-between-a-blower-and-an-open-air-gpu-cooler/" target="_blank" rel="noopener" title="What’s the Difference Between a Blower and an Open-Air GPU Cooler?">What’s the Difference Between a Blower and an Open-Air GPU Cooler?</a></li></ul><p>搭建多 GPU 工作站的要点是避免显卡过热与电源过载,其它很多方面与单 GPU 工作站的原则相似,没有提到的方面(包括显卡的选择)请参考<a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485265&idx=1&sn=9da954f0dd5c00884a6b3b92d7f2349a&chksm=cf06b66df8713f7b7ff0200a8e3b7ab324f3e6a72a0355cd5f9d0479cf802e4bb5df29d21ebc&token=1032059941&lang=zh_CN#rd" target="_blank" rel="noopener">《2020 年 10 月的单 GPU 深度学习工作站配置指南》</a>。</p><h1 id="双-GPU-工作站"><a href="#双-GPU-工作站" class="headerlink" title="双 GPU 工作站"></a>双 GPU 工作站</h1><h2 id="PCI-E-带宽"><a href="#PCI-E-带宽" class="headerlink" title="PCI-E 带宽"></a>PCI-E 带宽</h2><p>随着 GPU 的增加,模型训练的并行程度和 GPU 之间的数据传输增加,PCI-E 带宽变得越来越重要。然而对于双 GPU 工作站来说,PCI-E 带宽的重要性仍然有限。已经有人对 PCI-E 3.0 下 x16 和 x8 通道进行过测试<a href="https://www.pugetsystems.com/labs/hpc/PCIe-X16-vs-X8-with-4-x-Titan-V-GPUs-for-Machine-Learning-1167/" target="_blank" rel="noopener" title="PCIe X16 vs X8 with 4 x Titan V GPUs for Machine Learning"></a>,结论是影响非常小。那么在 x4 甚至 x2 或 x1 时带宽对深度学习有影响吗?目前还不清楚。</p><p>一个现实是 CPU 拥有的 PCI-E 通道是有限的:</p><table><thead><tr><th>CPU</th><th>支持 PCI-E 等级</th><th>通道数</th></tr></thead><tbody><tr><td>Ryzen 3000/5000</td><td>4.0</td><td>24</td></tr><tr><td>Core</td><td>3.0</td><td>20</td></tr></tbody></table><p>而有限的通道中至少要给 NVME 存储器分配 4~8 个通道。既然 x8 通道对深度学习没什么影响,双 GPU 完全可以使用双 x8 通道。这里支持 PCI-E 4.0 的优势显示出来了,一个 x8 PCI-E 4.0 通道相当于一个 x16 PCI-E 3.0 通道(30 系列显卡才支持 PCI-E 4.0)。双 x8 PCI-E 通道并联被 NVIDIA 称为 SLI 技术,高端芯片组 X570 和 Z490 都支持 SLI,所以在买主板的时候留意是否支持 SLI 就可以了。根据我的经验,只要主板上的两个 PCI-E 插槽都有金属包装,很可能就支持 SLI:</p><p><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuFbMuB1RCg1GmLia97uvRib5E8oyiafyBcmn3aEDqax0VEpNhWtia1AzRmg/0?wx_fmt=jpeg" alt="GIGABYTE X570 AORUS PRO:请注意上面两个 PCI-E x16 插槽有金属包装"></p><p>支持 SLI 的主板上如果还有第三个 PCI-E x16 插槽,这个插槽的通道要么走主板芯片要么与第二个插槽平分通道。比如上图,如果在最上面的两个插槽的任意一个中插一张卡,则为 x16 通道;在上面两个插槽插两张卡,则为 8-8 通道;三个插槽都插卡,则为 8-8-4 或 8-4-4 通道。</p><p>X570 主板中支持双路 x8 PCI-E 的型号有:</p><ul><li>ASRock X570 Phantom Gaming X</li><li>ASRock X570 Creator</li><li>ASRock X570 TAICHI</li><li>ASUS PRIME X570-PRO</li><li>ASUS AMD AM4 ROG 
Crosshair VIII Hero</li><li>ASUS ROG Strix X570-E Gaming</li><li>ASUS ROG Strix X570-F Gaming</li><li>GIGABYTE X570 AORUS PRO</li><li>GIGABYTE X570 AORUS ULTRA</li><li>GIGABYTE X570 AORUS MASTER</li><li>GIGABYTE X570 AORUS XTREME</li><li>MSI MEG X570 ACE Gaming</li><li>MSI MEG X570 UNIFY</li><li>MSI MEG X570 GODLIKE</li></ul><p>MSI MEG X570 GODLIKE 有 4 个 x16 PCI-E 插槽,前三个可以以 8-4-4 通道数连接;第四个 PCI-E 插槽走主板芯片以 4 条通道连接(8-4-4-4)。</p><p><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkHu2GuyEJXjiaRyiaVPbc4m1nV5AiaVW1QqRkLgBlpiaU9SUZukK8IvzqiaOL0rdpJkwnDraCJeGcASITA/0?wx_fmt=png" alt="MSI MEG X570 GODLIKE"></p><p>Z490 主板中支持双路 x8 PCI-E 的型号有:</p><ul><li>ASRock Z490 Taichi</li><li>ASUS ProART Z490-CREATOR</li><li>ASUS ROG STRIX Z490-E GAMING</li><li>ASUS ROG MAXIMUS XII APEX</li><li>ASUS ROG MAXIMUS XII FORMULA</li><li>GIGABYTE Z490 VISION</li><li>GIGABYTE Z490 AORUS PRO AX</li><li>GIGABYTE Z490 AORUS ULTRA</li><li>GIGABYTE Z490 AORUS MASTER</li><li>GIGABYTE Z490 AORUS ULTRA</li><li>MSI MPG Z490 GAMING CARBON</li><li>MSI MEG Z490 UNIFY</li><li>MSI MEG Z490 ACE</li></ul><p>X570 和 Z490 芯片组是最高端的芯片组,比 B550 和 B460 贵一些;支持 SLI 的功能算是进阶设计,价格要更贵一些。</p><h2 id="CPU、内存、电源的选择"><a href="#CPU、内存、电源的选择" class="headerlink" title="CPU、内存、电源的选择"></a>CPU、内存、电源的选择</h2><ul><li>理论上 4 核 CPU 足够,如果有很多预处理任务也可以买 6 核的 3600 和 10400F 或者 8 核的 3700x 和 10700F,再多就没有必要。</li><li>内存的大小看实际需求和 pipeline 设计,要么不小于<strong>单卡显存 + 6~8G</strong>,要么不小于<strong>显存之和 + 6~8G</strong>。</li><li>如果使用 4 核 CPU 配两张 3070 显卡,可选 750W 或 850W 电源;如果使用 6 核 CPU 配两张 3080/3090 显卡,至少要使用 1000W 电源。</li></ul><h2 id="散热"><a href="#散热" class="headerlink" title="散热"></a>散热</h2><p>如果安装两块 3070,发热与两块 2080 Ti 差不多,散热应该不是大问题;如果安装两块 3080 或 3090,请参考下面的散热部分。</p><h1 id="三-GPU-工作站"><a href="#三-GPU-工作站" class="headerlink" title="三 GPU 工作站"></a>三 GPU 工作站</h1><h2 id="PCI-E-带宽-1"><a href="#PCI-E-带宽-1" class="headerlink" title="PCI-E 带宽"></a>PCI-E 带宽</h2><p>如果希望三张卡都有至少 x8 带宽,Core 和 Ryzen 就不能满足了,必须是 Core X-Series,Xeon,Threadripper 或者 EPYC。我对 Xeon 和 EPYC 完全不了解,此处略。</p><table><thead><tr><th>CPU</th><th>支持 PCI-E 等级</th><th>通道数</th></tr></thead><tbody><tr><td>Threadripper</td><td>4.0</td><td>64</td></tr><tr><td>10 代 Core X-Series</td><td>3.0</td><td>48</td></tr></tbody></table><p>若主板上有三个 PCI-E 插槽,Intel X299 和 AMD sTRX40 主板都支持 16-8-16 分配;若有第四个插槽,sTRX40 可以支持 16-8-16-8 分配,而 X299 支持 8-8-8-8 分配。此处 AMD 的优势又体现出来了,不要说 Threadripper 支持更多的 PCI-E 通道,而且 PCI-E 4.0 x8 已经相当于全速 PCI-E 3.0 x16。Threadripper 唯二的缺点是贵和功耗大(然而未必比 Core X-Series 的满载功耗更大)。</p><p>ASRock TRX40 TAICHI 主板支持 16-16-16 通道分配,是 Threadripper 的最佳搭配。</p><p><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuia5DwicJDdticKcaBkz2rbU04ApTqn5ntJxibSWX5Z0Mbem5yicNSGXTtcw/0?wx_fmt=jpeg" alt=""></p><h2 id="供电"><a href="#供电" class="headerlink" title="供电"></a>供电</h2><p>常见的 CPU 与 GPU 的热设计功率(TDP)为:</p><table><thead><tr><th>CPU</th><th>TDP</th></tr></thead><tbody><tr><td>Threadripper</td><td>280W</td></tr><tr><td>Core X-Series</td><td>165W</td></tr></tbody></table><p>新 30 系列 GPU 的热设计功率为:</p><table><thead><tr><th>GPU</th><th>TDP</th></tr></thead><tbody><tr><td>RTX 3090</td><td>350W</td></tr><tr><td>RTX 3080</td><td>320W</td></tr><tr><td>RTX 3070</td><td>220W</td></tr></tbody></table><p>如果使用 Core 10920X 搭配三块 3070,推荐 1000W 电源;其它搭配推荐 1500W 电源。</p><h2 id="散热-1"><a href="#散热-1" class="headerlink" title="散热"></a>散热</h2><p>GPU 到了三块,散热开始需要重视,不然显卡会因为过热自动降频。显卡的散热方式有风冷和水冷两种,风冷又分涡轮式散热(blower)和开放式两种(open-air)两种。</p><ul><li>开放式散热:由风扇吸入冷空气,冷空气在散热片上进行热交换,热空气在 GPU 的周围排出。<br><img 
src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkHu2GuyEJXjiaRyiaVPbc4m1njCgJuv2dI8kKZiapLjsJuriaxfGcChwGcYlthiaribgicrmO7ZoXNLuIvibg/0?wx_fmt=png" alt=""></li><li>涡轮式散热:整个 PCB 板被包裹起来,冷空气被风扇吸入后在散热片上进行热交换后在 GPU 后挡板处排出。<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkHu2GuyEJXjiaRyiaVPbc4m1nwK6j7GntzNUHFtVlibfvvicERAaROXGFyKXgrMwHu69rCica4DiapWEquQ/0?wx_fmt=png" alt=""></li><li>水冷散热:冷水被水泵抽到芯片上吸收芯片的热量,热水随后被抽到散热片与冷空气进行热交换。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHcWd8F0HgaTHaTkAP1opu0jlexAribzFEzmGCI0xtrN22pY5B5RpHrVFkVHrQB2p0OedGq5o9CqNg/0?wx_fmt=jpeg" alt=""></li></ul><p>使用开放式散热的显卡会面临热空气被其它显卡吸收的问题,会降低散热的效果,极端情况下会造成显卡过热自动降频,从而降低性能。如果显卡之间有超过 1 个 PCI-E 空位,则基本不会存在散热的问题,但是这样由于空间的限制可能仅可以使用双卡;对于三卡工作站而言,涡轮式散热显卡或水冷散热显卡是必需的,然而是否可行仍需实践。</p><h2 id="风道与机箱的选择"><a href="#风道与机箱的选择" class="headerlink" title="风道与机箱的选择"></a>风道与机箱的选择</h2><p>当使用了 3 块以上的 GPU 以后,机箱的风道变得很重要,否则热空气会在机箱内积累,一样会造成显卡过热。一款合适的深度学习服务器机箱应该有充足的内部空间和足够多放风扇的位置。我推荐两款机箱:</p><ul><li>Thermaltake Core X71<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHu2GuyEJXjiaRyiaVPbc4m1nxzP0vIGpX1uvpa7hfIWPCst1McuSibTkRSXWIWM8ROCeXVcVQXSiaKBg/0?wx_fmt=jpeg" alt=""></li></ul><p>这个机箱的优点是可以装下足够多的风扇(上面 3 个,前面 2 个,下面 3 个,后面 1 个),非常适合多个水冷设备。</p><ul><li>Corsair Carbide Series Air 540<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHu2GuyEJXjiaRyiaVPbc4m1nG5sYicnklDFYxRibBMbo5ucYaQR1wX0kicS2MQkxuTmpKAXLeBnIR1PNw/0?wx_fmt=jpeg" alt=""></li></ul><p>这个机箱的优点是内部空间非常充足,可以安装风扇的位置也不少(上面 3 个,前面 2 个,后面 1 个)。</p><h2 id="显卡选择"><a href="#显卡选择" class="headerlink" title="显卡选择"></a>显卡选择</h2><p>如果显卡之间有足够的空间,那么可以使用开放式散热显卡;3 块以上显卡空间有限,需要使用涡轮式散热显卡或水冷显卡。现在各个厂商只发布了开放式散热设计的显卡,下面的型号可能还没有公开发布:</p><ol><li>涡轮式散热显卡:</li></ol><ul><li>GIGABYTE MSI GeForce RTX 3090 TURBO 24G<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuUu4grYmKoVK8ibGX8uG6vOG9icU4OiakMnMmRY78jDYITHgUdchU7A5qA/0?wx_fmt=png" alt=""></li></ul><ol start="2"><li>水冷散热显卡:</li></ol><ul><li>Colorful iGame Neptune GeForce RTX 30 系列<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuXWWBtiaBIGyLLLFJOdIgvicfyibzIMdEu4VH4ewV0ZpLtcdyALLdCw7Og/0?wx_fmt=png" alt=""></li><li>EVGA GeForce RTX 3080 10GB HYDRO COPPER<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuUZibcNM1cWTl5kHDliaQIAGHHJpoqDPIeicd9iboNPjlxBo1HuUCkchqsg/0?wx_fmt=jpeg" alt=""></li><li>EVGA GeForce RTX 3090 KINGPIN Hybrid<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOunK2Bdu2oUgcqrg6UcWnO5vmNg0RDib8OFeBeI7BY3XFx03CE3HiavBrQ/0?wx_fmt=png" alt=""></li></ul><h1 id="四-GPU-工作站"><a href="#四-GPU-工作站" class="headerlink" title="四 GPU 工作站"></a>四 GPU 工作站</h1><h2 id="供电-1"><a href="#供电-1" class="headerlink" title="供电"></a>供电</h2><p>如果使用四张显卡,应该把主机放在专业机房内;在普通民用环境中目前只可能使用四张 3070 显卡,推荐 1500W 电源。</p><p>美国电脑供应商 Puget Systems 近期发表了一篇研究搭建一台<a href="https://www.pugetsystems.com/labs/articles/Quad-GeForce-RTX-3090-in-a-desktop---Does-it-work-1935/" target="_blank" rel="noopener" title="Quad GeForce RTX 3090 in a desktop - Does it work?">拥有 1~4 张 GIGABYTE MSI GeForce RTX 3090 TURBO 24G 显卡的工作站</a>的可能性的博客。当使用 4 块 3090 显卡时,使用了双 1600W 供电。在美国,3 块 3090 已经接近了普通民用电路的供电极限。<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHu2GuyEJXjiaRyiaVPbc4m1nlvY1aZuSEpILicYCiaqh3HzmiaNxlbqiafmFYgFFwAUEU9Nq9DIJV8LKXQ/0?wx_fmt=jpeg" alt=""></p><h2 id="主板的选择"><a href="#主板的选择" class="headerlink" title="主板的选择"></a>主板的选择</h2><p>如果使用四张显卡,最好每张显卡都有 8 条通道。对于 Threadripper 来说,目前唯一的选择是 Gigabyte TRX40 DESIGNARE Motherboard:<br><img 
src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuk1REyycb9awpQyFMegvc2rNibHrIEjKwu3UsxeE9hGicdPTt0a3KAbfQ/0?wx_fmt=jpeg" alt=""></p><p>而对于 Core X-Series 来说,可以选择以下主板:</p><ul><li>GIGABYTE X299X AORUS MASTER(8-8-8-8 通道)</li><li>MSI Creator X299 LGA(8-8-16-8 通道)</li><li>MSI MEG X299 CREATION(8-8-16-8 通道)</li><li>EVGA X299 DARK(8 x 3 + 16 x 2 通道)</li></ul><p>还有两张主板有 7 个 PCI-E 插槽,因为有桥接芯片,支持 4 路 x16 PCI-E 通道:</p><ul><li>GIGABYTE X299-WU8<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOuB11KH3PlVCYh5o2FqN3Z4piagwxDAwAhqHAzic7R5SLqpMM8u5YcOIJg/0?wx_fmt=jpeg" alt=""></li><li>ASUS WS X299 SAGE<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkHD2fLPmuYNV1VI4qqLibvOu1ibrHcTfTLw3KJbbmVEUDq6aY2Jg06SRZiaRCyDpC1uI7WCDXAkmRJEA/0?wx_fmt=jpeg" alt=""></li></ul><h2 id="CPU-与内存"><a href="#CPU-与内存" class="headerlink" title="CPU 与内存"></a>CPU 与内存</h2><p>Threadripper 是 24 核起,Core X-Series 是 12 核起,配 4 张 GPU 足够用了。</p><p>内存请参考双 GPU 部分。</p><hr><p>现在是购买 RTX 30 系列显卡的好时候吗?我认为不是。</p><ol><li>现在根本买不到啊;</li><li>深度学习框架对新 CUDA 和 CuDNN 的支持还不够;</li><li>各个厂家的显卡还没有开发完全;</li><li>新显卡的散热效果有待观察。</li></ol><p>NVIDIA 已经说了,目前的缺货会延续到 2021 年。我们还是耐心等待吧。另外也希望 Big Navi 的性能和供货给力,让本来打算买 N 卡的人去买 A 卡,给我们深度学习民工一条生路啊。</p>]]></content>
<tags>
<tag> 深度学习工作站 </tag>
</tags>
</entry>
<entry>
<title>[工作站] 新 RTX 3090 搭建深度学习工作站的一些思考</title>
<link href="2020/10/10/%E5%B7%A5%E4%BD%9C%E7%AB%99-%E6%96%B0-RTX-3090-%E6%90%AD%E5%BB%BA%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E7%9A%84%E4%B8%80%E4%BA%9B%E6%80%9D%E8%80%83/"/>
<url>2020/10/10/%E5%B7%A5%E4%BD%9C%E7%AB%99-%E6%96%B0-RTX-3090-%E6%90%AD%E5%BB%BA%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E7%9A%84%E4%B8%80%E4%BA%9B%E6%80%9D%E8%80%83/</url>
<content type="html"><![CDATA[<p>今天在班上看完了 NVIDIA 的 GeForce RTX 30 系列发布会。看完感觉游戏玩家应该做梦都会笑醒:</p><table><thead><tr><th></th><th>RTX 3070</th><th>RTX 3080</th><th>RTX 3090</th><th>RTX 2080 Ti</th></tr></thead><tbody><tr><td>CUDA Core</td><td>5888</td><td>8704</td><td>10496</td><td>4352</td></tr><tr><td>Core Clock</td><td>1500 MHz</td><td>1440 MHz</td><td>1400 MHz</td><td>1350 MHz</td></tr><tr><td>Boost Clock</td><td>1730 MHz</td><td>1710 MHz</td><td>1700 MHz</td><td>1545 MHz</td></tr><tr><td>Memory Capacity</td><td>8 GB GDDR6</td><td>10 GB GDDR6X</td><td>24 GB GDDR6X</td><td>11 GB GDDR6</td></tr><tr><td>Memory Bus</td><td>256 bit</td><td>320 bit</td><td>384 bit</td><td>352 bit</td></tr><tr><td>Memory Speed</td><td>16 Gbps</td><td>19 Gbps</td><td>19.5 Gbps</td><td>14 Gbps</td></tr><tr><td>Memory Bandwidth</td><td>512 GB/s</td><td>760 GB/s</td><td>936 GB/s</td><td>616 GB/s</td></tr><tr><td>TDP</td><td>220W</td><td>320W</td><td>350W</td><td>275W</td></tr><tr><td>MSRP</td><td>$499 US</td><td>$699 US</td><td>$1499 US</td><td>$999 US</td></tr></tbody></table><a id="more"></a><p>一张中端卡吊打上代卡皇:<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkEPVr1buRIXzb1FIicicJWBcW2ZmupLZREtGLqic6LdgVEh2Jqahxzjotl204JUVptia0HICk0NTVLCzQ/0?wx_fmt=png" alt=""><br>新一代卡皇 RTX 3090 据老黄说可以以 8k 分辨率全开特效无压力玩任何游戏。苏妈现在应该是压力山大吧。</p><p>RTX 30 系列对游戏玩家是一个巨大的提升,但对深度学习研究者呢?假设想组一台 4 块 RTX 3090 的服务器,个人认为用现有的 3090 显卡是不现实的。原因有三:</p><ol><li>功耗是一个巨大的隐患</li></ol><p>普通居民区、写字楼内的单根电线的最大载荷约为 2000W,超过这个数字会跳闸,用电器必须放在专门设计的机房里。以前 RTX 2080 Ti 在超频时最大功耗不到 350W,四个 RTX 2080 Ti 加上 CPU 的功耗接近 2000W,还可以放在办公室里,而 RTX 3090 的 TDP 已经是 350W。NVIDIA 专门设计了一个 12 pin 的供电接口,最大载荷暂时未知,但有传言说最大载荷为 600W。而 RTX 2080 Ti 的供电接口为双 8 pin,其最大载荷为 150W(单个 8 pin 接口的供电)* 2 + 75W(PCIE 插槽供电)= 375W。目前已知的 12 pin 接口均为双 8 pin 转接而来,其最大载荷不会超过 375W;一旦电脑电源原生支持 12 pin 接口,功耗就难说了。</p><ol start="2"><li>版型太大,主板上容纳不下</li></ol><p>GeForce RTX 3090 is massive…这可能是第一张单卡三插槽的显卡,PCB 板在中间,前后都有风扇。<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkEeWLg0UUrtsztYiasVDzTXABFqoPQKnazZ2kE8mWQCHoBaRj7kU1hHdJ9uwibp1bK6XlgfDic3fa7zw/0?wx_fmt=png" alt=""><br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkEPVr1buRIXzb1FIicicJWBcWu8DdXpE8v7U7BrHs6S0IicAhWhibA8R43xoBuDrvF0NcVwHzCeosGicbQ/0?wx_fmt=jpeg" alt="RTX 3090 与 RTX 2080 Ti 对比"><br>三插槽就有问题了:主板上没位置插,机箱里也装不下。比如下面的一张主板:<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkEPVr1buRIXzb1FIicicJWBcWk7LslqLFYqsTgNwBG91kKM2ZZeL5HH1ctjF97Via4QpbX070Y9BMpCg/0?wx_fmt=png" alt=""><br>一般的双槽显卡可以插四张,而 RTX 3090 只能插两张(第 1 个和第 3 个 x16 插槽),因为一张 RTX 3090 需要两边都有空位。就算主板上有充足的空间,一般机箱上只有 8 个 PCI-E 槽位。所以想要 4 x RTX 3090,需要更大的机箱(EATX 或者 WATX)和专用的主板,或者希望以后会出双卡槽的 RTX 3090。</p><ol start="3"><li>散热问题</li></ol><p>如果只有 1 张显卡,那么散热不是主要问题;而如果想放多张显卡,那么显卡一定要是涡轮散热设计。话说回来,显卡的风冷散热主要分涡轮散热和涡扇散热两种,主要区别在于出风的方向。涡轮散热的出风口在显卡的后挡板:<br><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkEeWLg0UUrtsztYiasVDzTXAn2SxuPxZMJJpypazcqu9ClqDg4Mp8C4qS3msZccQ0UToNyUd02djpA/0?wx_fmt=jpeg" alt=""><br>而涡扇散热的出风口在显卡的四周:<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkEeWLg0UUrtsztYiasVDzTXAXdYqgXaCg16lKakuakvSlUqEzenoH5Po4AUKHibzpBbicBqETcMMmwRw/0?wx_fmt=png" alt=""><br>如果机箱里有超过两张显卡,那么一定要选择涡轮散热,否则散热排出的热风会被其他显卡重新吸入,造成散热失效。</p><p>目前公布的所有 RTX 3090 都是涡扇设计,所以并不适合多块显卡组装深度学习工作站。而公版 RTX 3090 简直是单卡的福音,多卡的噩梦,看看 Founders Edition 的风道设计:<br><img src="https://mmbiz.qpic.cn/mmbiz_png/OuEFiapfBFkEeWLg0UUrtsztYiasVDzTXAtF8FicIPSbeQHEzxOV8DNpghgDqD0xy4dUCwNpukaD814HaWJalwfLA/0?wx_fmt=png" alt=""><br>假如几块 RTX 3090 并联排列,头一块显卡的热风直接被第二块显卡吸入,以此类推…</p><p>现在已经有板厂(貌似是 EVGA)计划推出双卡槽的水冷散热版 RTX 3090。不过四个水冷排也不好摆。</p><p>那两张 RTX 3090 能超过四张 RTX 2080 Ti 吗?Only time will tell.</p>
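<p>顺便分享一个判断显卡是否过热的小办法。下面这段脚本基于 pynvml(pip 包名为 nvidia-ml-py3)轮询每张显卡的温度和实时功耗,只是一个示意性质的草稿:轮询间隔 5 秒和 80 摄氏度的报警阈值都是我随手设的假设值,并非官方建议。</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line"># 轮询每张显卡的温度与功耗(假设已安装 nvidia-ml-py3)</span><br><span class="line">import time</span><br><span class="line">import pynvml</span><br><span class="line"></span><br><span class="line">pynvml.nvmlInit()</span><br><span class="line">handles = [pynvml.nvmlDeviceGetHandleByIndex(i)</span><br><span class="line">           for i in range(pynvml.nvmlDeviceGetCount())]</span><br><span class="line"></span><br><span class="line">while True:</span><br><span class="line">    for i, h in enumerate(handles):</span><br><span class="line">        temp = pynvml.nvmlDeviceGetTemperature(h, pynvml.NVML_TEMPERATURE_GPU)  # 摄氏度</span><br><span class="line">        power = pynvml.nvmlDeviceGetPowerUsage(h) / 1000  # 毫瓦转瓦</span><br><span class="line">        warn = '(注意:可能过热降频)' if temp >= 80 else ''  # 80 度是示例阈值</span><br><span class="line">        print(f'GPU {i}: {temp}C {power:.0f}W {warn}')</span><br><span class="line">    time.sleep(5)</span><br></pre></td></tr></table></figure>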
<hr><p>有传言说 NVIDIA 其实还有 20GB 显存的 RTX 3080 (个人猜测可能会被叫做 super)和 16GB 显存的 RTX 3070 (super)没有被发布,可能是在观望 AMD 的下一代显卡 Big Navi 的表现。双卡槽的 RTX 3080 如果有 20GB 显存无疑会解决以上的所有问题,成为深度学习 GPU 的理想工具。希望苏妈能够给力一点,让老黄早点发布 RTX 3080 super 和 RTX 3070 super。</p><p>现在没有合适的显卡,不代表以后也没有。反正我准备入一块 RTX 3070 先用着,等 20GB RTX 3080 super 出了直接换 4 卡哈哈哈。</p>]]></content>
<tags>
<tag> 深度学习工作站 </tag>
</tags>
</entry>
<entry>
<title>[工作站] 2020 年 10 月的单 GPU 深度学习工作站配置指南</title>
<link href="2020/09/20/%E5%B7%A5%E4%BD%9C%E7%AB%99-2020-%E5%B9%B4-10-%E6%9C%88%E7%9A%84%E5%8D%95-GPU-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E9%85%8D%E7%BD%AE%E6%8C%87%E5%8D%97/"/>
<url>2020/09/20/%E5%B7%A5%E4%BD%9C%E7%AB%99-2020-%E5%B9%B4-10-%E6%9C%88%E7%9A%84%E5%8D%95-GPU-%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0%E5%B7%A5%E4%BD%9C%E7%AB%99%E9%85%8D%E7%BD%AE%E6%8C%87%E5%8D%97/</url>
<content type="html"><![CDATA[<p>随着电脑硬件的性能提升、价格下降,搭建个人用深度学习工作站的支出越来越低,需求也会越来越大。因此从今年开始,每年的 5、10 月份均会发布最新的深度学习工作站的配置指南。</p><hr><p>随着 NVIDIA 的新一代 Ampere 架构的 GeForce 30 系列显卡的发布,在算力得到了极大提升的同时价格也大幅下降,花费不到 $1000 搭建一台性能强大的深度学习工作站已经成为了可能。适逢 AMD 的新一代 Ryzen 处理器也在 10 月 8 日发布,硬件性能的提升毫无疑问会再次推动深度学习的热潮。</p><p>我最近准备搭建自己的第一台深度学习工作站,本文(单 GPU 工作站)与下一篇文章(多 GPU 工作站)正是基于本人最近的研究。水平有限,没有实践,欢迎指正。</p><a id="more"></a><p>本文主要参考了以下文章:</p><ul><li><a href="https://timdettmers.com/2020/09/07/which-gpu-for-deep-learning/" target="_blank" rel="noopener" title="Which GPU(s) to Get for Deep Learning: My Experience and Advice for Using GPUs in Deep Learning">Which GPU(s) to Get for Deep Learning: My Experience and Advice for Using GPUs in Deep Learning</a></li><li><a href="https://timdettmers.com/2018/12/16/deep-learning-hardware-guide/" target="_blank" rel="noopener" title="A Full Hardware Guide to Deep Learning">A Full Hardware Guide to Deep Learning</a></li><li><a href="https://lambdalabs.com/blog/deep-learning-hardware-deep-dive-rtx-30xx/#blower-gpus" target="_blank" rel="noopener" title="Deep Learning Hardware Deep Dive – RTX 3090, RTX 3080, and RTX 3070">Deep Learning Hardware Deep Dive – RTX 3090, RTX 3080, and RTX 3070</a></li></ul><hr><p>工作站与个人游戏电脑不同,在配置上有一些需要改变的地方。对于深度学习来说,目前的唯一选择是 NVIDIA 的 GPU 产品;又因为本文的主题是个人深度学习工作站,所以本文仅涉及 NVIDIA 的 GeForce 系列消费级显卡(Tesla 以及 Quadro 这两个系列名都已经成为历史)。本文首先来讨论深度学习工作站 must have 的部分,然后是 nice to have 的部分,再后是 don’t matter much 的部分,最后是 try to avoid 的部分。</p><h1 id="Must-Have"><a href="#Must-Have" class="headerlink" title="Must Have"></a>Must Have</h1><p>这部分不够就不行,但是超过也完全没用。</p><h2 id="显存"><a href="#显存" class="headerlink" title="显存"></a>显存</h2><p>通常来说,对显存的要求如下:</p><ul><li>研究 SOTA 模型:>= 11GB</li><li>一般的研究:8GB</li><li>Kaggle 及其它竞赛:4 - 8GB</li><li>公司业务:8GB 用于部署及原型测试,>= 11GB 用于训练</li></ul><p>对应到 RTX 30xx 系列显卡来说,可将 3060(6GB 显存),3060 Ti/3070 (8GB 显存)/3080(10GB 显存),3070 Super(16GB 显存)/3080 Super(20GB 显存)/3090(24GB 显存)对号入座。</p><p><strong>注</strong>: 3070 显卡将于 10 月 29 日上市,3060 Ti/3070 Super/3080 Super 预计在今年底前会陆续发布,3060 预计在明年年初发布。</p><h2 id="内存"><a href="#内存" class="headerlink" title="内存"></a>内存</h2><p>到底需要多少内存难以下定论,而 Tim Dettmers 说“额外的内存对特征工程非常有帮助”。综上,本人的推荐是 <em>内存容量 = 显存容量 + 6 ~ 8GB</em>。</p><h2 id="电源"><a href="#电源" class="headerlink" title="电源"></a>电源</h2><p>没人想在训练一半的时候因为供电不足而电脑重启,因此要预留足够的电源供电。主机内耗电的部分主要为 GPU、CPU 和主板上的其它部件。通过研究 GPU 与 CPU 的耗电数据,我发现 GPU 的峰值功耗要超过 TDP 100w 左右,而 8 核以下的 CPU 的峰值功耗大概可以归纳为</p><table><thead><tr><th>核心数量</th><th>峰值功耗(w)</th></tr></thead><tbody><tr><td>4</td><td>100</td></tr><tr><td>6</td><td>150</td></tr><tr><td>8</td><td>200</td></tr></tbody></table><p>主板的功耗(内存和硬盘之类的总和)大概为 80w,故电源功率的最低要求为:CPU 峰值功耗 + GPU 峰值功耗 + 80。因为 CPU 和 GPU 很少同时满负荷工作,所以不需要额外考虑冗余。比如 RTX 3080 的 TDP 为 320w,故一台 Ryzen 5 3600 与一张 RTX 3080 的工作站需要一个额定功率最少为 150 + 320 + 100 + 80 = 650w 的电源(本节末尾给出一段按此公式计算的示意代码)。</p><p>另外不像游戏主机在不运行的时候关闭,工作站一般是 7 * 24 小时开机的,所以电源的转换效率也很重要。以下为 80 Plus 认证在 115V 电压下 100% 负载时的转换效率表<a href="https://en.wikipedia.org/wiki/80_Plus" target="_blank" rel="noopener" title="80 Plus 介绍"></a>:</p><table><thead><tr><th>认证等级</th><th>转换效率</th></tr></thead><tbody><tr><td>White</td><td>80%</td></tr><tr><td>Bronze</td><td>82%</td></tr><tr><td>Silver</td><td>85%</td></tr><tr><td>Gold</td><td>87%</td></tr><tr><td>Platinum</td><td>89%</td></tr><tr><td>Titanium</td><td>90%</td></tr></tbody></table><p>一般来说,功耗在 600w 以下 Bronze 就可以了,600w ~ 1000w 之间推荐 Gold,1000w 以上推荐 Platinum 或 Titanium。</p>
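<p>把上面的估算方法写成一小段 Python,方便快速比较不同配置。这只是按照本文经验公式的示意:GPU 峰值按 TDP + 100W、主板及其它部件按 80W 估算,CPU 峰值功耗取自上表,都是本文的假设,并非厂商数据。</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line"># 按本文经验公式估算电源的最低额定功率(示意)</span><br><span class="line">CPU_PEAK = {4: 100, 6: 150, 8: 200}  # 核心数 -> 峰值功耗(W),来自上表</span><br><span class="line"></span><br><span class="line">def min_psu_watts(cpu_cores, gpu_tdp):</span><br><span class="line">    gpu_peak = gpu_tdp + 100  # 假设:GPU 峰值约超出 TDP 100W</span><br><span class="line">    board = 80                # 假设:主板、内存、硬盘等合计约 80W</span><br><span class="line">    return CPU_PEAK[cpu_cores] + gpu_peak + board</span><br><span class="line"></span><br><span class="line">print(min_psu_watts(6, 320))  # Ryzen 5 3600 + RTX 3080 -> 650</span><br></pre></td></tr></table></figure>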
<h1 id="Nice-to-Have"><a href="#Nice-to-Have" class="headerlink" title="Nice to Have"></a>Nice to Have</h1><p>以上的因素决定了模型能不能训练,下面的因素决定了训练模型的速度和操作者的体验。</p><h2 id="Tensor-Core"><a href="#Tensor-Core" class="headerlink" title="Tensor Core"></a>Tensor Core</h2><p>Tensor Core 可以极大地加快矩阵乘法,深度学习优先使用 Tensor Core 进行训练。由于 RTX 显卡的 Tensor Core 可以以半精度(16bit)进行训练,显存需求减半,所以相比 GTX 显卡在同样的显存下可以训练大一倍的模型,因此除非预算极度有限,应该优先考虑 RTX 20/30 系列显卡。一张显卡有多少 Tensor Core 决定了这张显卡的算力,而 Tensor FLOPS 则量化了显卡的算力<a href="https://en.wikipedia.org/wiki/GeForce_20_series" target="_blank" rel="noopener" title="GeForce 20 series 系列介绍"></a><a href="https://www.reddit.com/r/buildapcsales/comments/ikqulm/meta_information_thread_for_the_nvidia_30series/" target="_blank" rel="noopener" title="[Meta] Information Thread for the Nvidia 30-series GPUs launch"></a>。</p><table><thead><tr><th>芯片型号</th><th>Tensor Core</th><th>Tensor FLOPS (万亿)</th><th>显存(GB)</th><th>TDP (W)</th><th>MSRP (USD)</th></tr></thead><tbody><tr><td>2060</td><td>240</td><td>51.6</td><td>6</td><td>160</td><td>349</td></tr><tr><td>2060 super</td><td>272</td><td>57.4</td><td>8</td><td>175</td><td>399</td></tr><tr><td>2070 super</td><td>320</td><td>72.5</td><td>8</td><td>215</td><td>499</td></tr><tr><td>2080 super</td><td>384</td><td>89.2</td><td>8</td><td>250</td><td>699</td></tr><tr><td>2080 Ti</td><td>544</td><td>107.6</td><td>11</td><td>250</td><td>999</td></tr><tr><td>Titan RTX</td><td>576</td><td>130.5</td><td>24</td><td>280</td><td>2499</td></tr><tr><td>3070</td><td>184</td><td>163</td><td>8</td><td>220</td><td>499</td></tr><tr><td>3080</td><td>272</td><td>238</td><td>10</td><td>320</td><td>699</td></tr><tr><td>3090</td><td>328</td><td>285</td><td>24</td><td>350</td><td>1499</td></tr></tbody></table><p>虽然 30 系列的 Tensor Core 数量比 20 系列少,但官方称 30 系列的 Tensor Core 的性能是 20 系列的 4 倍,所以(如果官方宣传为真的话)3070 的实际算力要强于 2080 Ti。3060 Ti 尚未被官方确认,但估计其算力应该与 2080 Ti 相当。</p><p><strong>购买建议</strong>:消息说 3060 Ti 的官方指导价格是 $349。在 3060 Ti 存在的前提下 3070 是比较尴尬的存在,比上不足,比下有余。而 20 系列显卡毫无性价比,除非预算有限买二手 2060/2060 Super,否则不推荐。我的购买建议是:</p><ul><li>预算严重不足的入门菜鸟:二手 1660 Super</li><li>预算不足的入门菜鸟:二手 2060/2060 Super/2070</li><li>有一点预算的入门菜鸟:3060 Ti</li><li>中阶使用者:3070 Super/3080 Super</li><li>高阶使用者:3090</li></ul><p>现在避免购买任何 2080/2080 Super/2080 Ti 显卡(包括二手显卡)。</p><h2 id="外设"><a href="#外设" class="headerlink" title="外设"></a>外设</h2><p>除了 GPU,1 到 2 台额外的显示器和一把趁手的键盘可能是最有价值的投资。不过这部分比较主观,如何选择由各位读者考虑。</p><p><strong>购买建议</strong>:</p><ul><li>购买带翻转屏功能的显示器;我现在用的两台显示器中有一台 Dell 2718Q。</li><li>购买高分辨率的显示器;我现在用的是两台 27” 4K 显示器。</li><li>选择机械键盘。轴体根据自己的喜好选择,我现在用的是茶轴。</li></ul><p><img src="https://mmbiz.qpic.cn/mmbiz_jpg/OuEFiapfBFkEVjIE1DBlj7MuodfxytVDQ1mM3G5zCqM1rFhuFI3Ex7f3d9hnsva4w7AnxofoGWhhhnN57moeVSQ/0?wx_fmt=jpeg" alt="二次元真香"></p><h1 id="Don’t-Matter-Much"><a href="#Don’t-Matter-Much" class="headerlink" title="Don’t Matter Much"></a>Don’t Matter Much</h1><p>这部分对性能提升非常有限,不如节省下来减少开支。</p><h2 id="CPU-的核心数量和主频"><a href="#CPU-的核心数量和主频" class="headerlink" title="CPU 的核心数量和主频"></a>CPU 的核心数量和主频</h2><p>在深度学习中,CPU 的主要工作是数据预处理。有两种策略:</p><p><em>Loop</em>:</p><ul><li>Load mini-batch</li><li>Preprocessing mini-batch</li><li>Train on mini-batch</li></ul><p>或者</p><ul><li>Preprocess data</li><li>Loop:<ul><li>Load preprocessed mini-batch</li><li>Train on mini-batch</li></ul></li></ul><p>对于第一种策略,一颗强大的 CPU 会显著提高性能,推荐为 GPU 配备至少 4 个 CPU 线程;而第二种策略通常不需要非常好的 CPU,2 个线程足够。所以对于单 GPU 工作站而言,最低端的 Core i3 10100F 或者 Ryzen 3 3100 已经足够(两者都是 4 核心 8 线程),6 核以上完全没有必要(两种策略的代码示意见本小节末尾)。</p><p>而对于 CPU 频率而言,频率的影响非常有限(因为 CPU 在深度学习中不起主导作用),主频从 1.1GHz 提升到 3.6GHz 的综合性能提升在 4% ~ 8% 之间。</p>
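<p>用 PyTorch 的 Dataset/DataLoader 可以粗略示意这两种策略的区别。下面是一个示意草稿:transform 的内容、数据规模和 num_workers 的取值都是随手假设的。</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">import torch</span><br><span class="line">from torch.utils.data import Dataset, DataLoader</span><br><span class="line"></span><br><span class="line">class OnTheFlyDataset(Dataset):  # 策略一:取样本时实时预处理,吃 CPU</span><br><span class="line">    def __init__(self, raw, transform):</span><br><span class="line">        self.raw, self.transform = raw, transform</span><br><span class="line">    def __len__(self):</span><br><span class="line">        return len(self.raw)</span><br><span class="line">    def __getitem__(self, idx):</span><br><span class="line">        return self.transform(self.raw[idx])  # 每个 epoch 都重复计算</span><br><span class="line"></span><br><span class="line">class PreprocessedDataset(Dataset):  # 策略二:预处理一次并缓存</span><br><span class="line">    def __init__(self, raw, transform):</span><br><span class="line">        self.data = [transform(x) for x in raw]  # 提前算好</span><br><span class="line">    def __len__(self):</span><br><span class="line">        return len(self.data)</span><br><span class="line">    def __getitem__(self, idx):</span><br><span class="line">        return self.data[idx]</span><br><span class="line"></span><br><span class="line">raw = [torch.randn(3, 224, 224) for _ in range(1000)]    # 假数据</span><br><span class="line">transform = lambda x: (x - x.mean()) / (x.std() + 1e-6)  # 假设的预处理</span><br><span class="line"></span><br><span class="line">loader1 = DataLoader(OnTheFlyDataset(raw, transform), batch_size=32, num_workers=4)</span><br><span class="line">loader2 = DataLoader(PreprocessedDataset(raw, transform), batch_size=32, num_workers=2)</span><br></pre></td></tr></table></figure><p>策略一的预处理发生在每次取样本时,所以 worker 数量直接影响吞吐;策略二把预处理一次性做完,训练时 CPU 几乎只做搬运,这也是正文中“2 个线程足够”的原因。</p>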
<h2 id="PCI-E-等级-amp-通道数"><a href="#PCI-E-等级-amp-通道数" class="headerlink" title="PCI-E 等级 & 通道数"></a>PCI-E 等级 & 通道数</h2><p>PCI-E(Peripheral Component Interconnect Express)总线在 2003 年推出,取代了曾经的 PCI 和 AGP 总线,目前在使用的标准为 PCI-E 3.0 和 PCI-E 4.0。PCI-E 总线是一种串行总线,单个插槽上可以有 1、2、4、8、16 条通道,带宽如下:</p><table><thead><tr><th>PCI-E 版本</th><th>x1</th><th>x2</th><th>x4</th><th>x8</th><th>x16</th></tr></thead><tbody><tr><td>3.0</td><td>0.99GB/s</td><td>1.97GB/s</td><td>3.94GB/s</td><td>7.88GB/s</td><td>15.75GB/s</td></tr><tr><td>4.0</td><td>1.97GB/s</td><td>3.94GB/s</td><td>7.88GB/s</td><td>15.75GB/s</td><td>31.5GB/s</td></tr></tbody></table><p>可以看到,PCI-E 4.0 的带宽是 PCI-E 3.0 的两倍,因为 AMD 的 X570 和 B550 芯片组支持 PCI-E 4.0,而 Intel 要到明年上半年的 11 代 Core 才支持,所以使用 Ryzen 3 代以后的 CPU 和 500 系列主板会有带宽的优势。</p><p>通常显卡占用 8 或 16 条 PCI-E 通道,而一块 NVME M.2 存储器占用 4 条 PCI-E 通道。虽然显卡接收与传递数据经过 PCI-E 总线,然而在仅有 1 张显卡的时候,PCI-E 总线的带宽与级别对显卡性能的影响并不大,PCI-E 的带宽对文件的读取/写入性能的影响更大一点。不过在一个 pipeline 里面数据一般仅仅读取/写入一次,因此 PCI-E 4.0 或者 3.0 对性能影响有限。</p><h2 id="内存频率与延迟"><a href="#内存频率与延迟" class="headerlink" title="内存频率与延迟"></a>内存频率与延迟</h2><p>同上面一条,因为数据在 GPU 与 CPU 之间的交互次数有限,故速度更快、延迟更低的内存对性能提升有限。</p><h2 id="散热"><a href="#散热" class="headerlink" title="散热"></a>散热</h2><p>对于一台仅有 1 张显卡的工作站而言,散热不是需要考虑的问题。</p><h1 id="Try-to-Avoid"><a href="#Try-to-Avoid" class="headerlink" title="Try to Avoid"></a>Try to Avoid</h1><h2 id="超频"><a href="#超频" class="headerlink" title="超频"></a>超频</h2><p>对于长时间运行的工作站来说,超频会减少元件的寿命,降低系统的稳定性,增加功耗,因此超频是大忌。不要购买任何出厂预超频(Overclock 或 OC 版)的显卡或自己超频。</p><h2 id="灯光效果"><a href="#灯光效果" class="headerlink" title="灯光效果"></a>灯光效果</h2><p>不是说不能有光效,但是工作站是用来干活的,不是用来欣赏的,而且工作站一般放在不起眼的地方,有光效也看不见。看得见的光效除了分散注意力以外还耗费额外的电能(还费钱),实在没有意义。</p>]]></content>
<tags>
<tag> 深度学习工作站 </tag>
</tags>
</entry>
<entry>
<title>[NLP] 新手的第一个 NLP 项目:文本分类(4)</title>
<link href="2020/08/23/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E9%A1%B9%E7%9B%AE%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%884%EF%BC%89/"/>
<url>2020/08/23/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E9%A1%B9%E7%9B%AE%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%884%EF%BC%89/</url>
<content type="html"><![CDATA[<p>在之前的文章中,我们使用了 CNN 和 RNN 对 IMDB 数据集进行了分析,10 个 epoch 以后准确率不到 85%。除了使用更复杂的模型以外,我们还可以使用更好的词向量。本文中我们将使用 Bert 词向量和 GRU 层搭建另一个简单的神经网络模型。由于 transformers 涉及到大量计算,本文中将使用 Google Colab 提供的 GPU。</p><p>与前面的数据预处理流程不同,这里我们将使用 <code>torchtext</code> 来封装数据。有关 <code>torchtext</code> 的知识请看 <a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485080&idx=1&sn=aea4f1162268db972506f9df6046c9c1&chksm=cf06b7a4f8713eb214b2ffe39e3f792e9dc0990ff691a97fafae1f7361e300d5e6188da5fafe&token=1227810784&lang=zh_CN#rd" target="_blank" rel="noopener">PyTorch 折桂 13:TorchText</a>。</p><a id="more"></a><p>安装所需的包:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">!pip install -U torch <span class="comment"># 1.7</span></span><br><span class="line">!pip install -U torchtext <span class="comment"># 0.7</span></span><br><span class="line">!pip install -U transformers <span class="comment"># 3.0.2</span></span><br></pre></td></tr></table></figure><p>设置随机种子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"></span><br><span class="line">SEED = <span class="number">1988</span></span><br><span class="line"></span><br><span class="line">random.seed(SEED)</span><br><span class="line">np.random.seed(SEED)</span><br><span class="line">torch.manual_seed(SEED)</span><br><span class="line">torch.cuda.manual_seed(SEED)</span><br><span class="line">torch.backends.cudnn.deterministic = <span class="literal">True</span> <span class="comment"># 保证结果可复现(可能略微降低训练速度)</span></span><br></pre></td></tr></table></figure><h1 id="数据准备"><a href="#数据准备" class="headerlink" title="数据准备"></a>数据准备</h1><p>之前的文章中,我们仅仅使用了 <code><PAD></code> 来填充不足的空位;而在 Bert 里,除了 <code><PAD></code> 还使用了 <code><BOS></code> 和 <code><EOS></code> 来表示句子的开始和结束以及 <code><UNK></code> 来表示单词表以外的单词。另外 Bert 最多只接受 512 个 token,所以每句话只取前 512 个 token(其中两个位置要留给句首和句尾的特殊 token)。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">init_token_id = tokenizer.cls_token_id <span class="comment"># BOS</span></span><br><span class="line">eos_token_id = tokenizer.sep_token_id <span class="comment"># EOS</span></span><br><span class="line">pad_token_id = tokenizer.pad_token_id <span class="comment"># PAD</span></span><br><span class="line">unk_token_id = tokenizer.unk_token_id <span class="comment"># UNK</span></span><br><span class="line"></span><br><span class="line">max_length_input = tokenizer.max_model_input_sizes[<span class="string">'bert-base-uncased'</span>]</span><br></pre></td></tr></table></figure><p>上面的各个特殊 token 的 id 都来自 Bert 分词器,因此实际运行时需要先执行下面的代码:载入预训练好的 Bert 分词器,并以此构建分词函数(每句话截断到最大长度减 2,给句首句尾的特殊 token 留位置)。</p><figure class="highlight 
py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> transformers <span class="keyword">import</span> BertTokenizer</span><br><span class="line">tokenizer = BertTokenizer.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">tokenize_and_cut</span><span class="params">(sentence)</span>:</span></span><br><span class="line"> tokens = tokenizer.tokenize(sentence)</span><br><span class="line"> tokens = tokens[:max_length_input - <span class="number">2</span>]</span><br><span class="line"> <span class="keyword">return</span> tokens</span><br></pre></td></tr></table></figure><p>下一步是构建数据集的域。所谓“域”指的是数据集里对文本与标签的处理方式的声明。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torchtext.data <span class="keyword">import</span> Field, LabelField</span><br><span class="line"></span><br><span class="line">TEXT = Field(batch_first=<span class="literal">True</span>,</span><br><span class="line"> use_vocab=<span class="literal">False</span>,</span><br><span class="line"> tokenize=tokenize_and_cut,</span><br><span class="line"> preprocessing=tokenizer.convert_tokens_to_ids,</span><br><span class="line"> init_token=init_token_id,</span><br><span class="line"> eos_token=eos_token_id,</span><br><span class="line"> pad_token=pad_token_id,</span><br><span class="line"> unk_token=unk_token_id)</span><br><span class="line"></span><br><span class="line">LABEL = LabelField(dtype=torch.float)</span><br></pre></td></tr></table></figure><p>最后就是读取与封装数据。<code>batch size</code> 设置为 64。因为使用 GPU 训练,数据需要转移到 GPU 上。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torchtext <span class="keyword">import</span> datasets</span><br><span class="line"></span><br><span class="line">train, test = datasets.IMDB.splits(TEXT, LABEL)</span><br><span class="line">LABEL.build_vocab(train)</span><br><span class="line"><span class="keyword">from</span> torchtext.data <span class="keyword">import</span> BucketIterator</span><br><span class="line"></span><br><span class="line">BATCH_SIZE = <span class="number">64</span></span><br><span class="line"></span><br><span class="line">device = 
torch.device(<span class="string">'cuda'</span> <span class="keyword">if</span> torch.cuda.is_available() <span class="keyword">else</span> <span class="string">'cpu'</span>)</span><br><span class="line"></span><br><span class="line">train_iter, test_iter = BucketIterator.splits(</span><br><span class="line"> (train, test),</span><br><span class="line"> batch_size=BATCH_SIZE,</span><br><span class="line"> device=device</span><br><span class="line">)</span><br></pre></td></tr></table></figure><h1 id="模型搭建"><a href="#模型搭建" class="headerlink" title="模型搭建"></a>模型搭建</h1><p>载入 Bert 预训练模型:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> transformers <span class="keyword">import</span> BertTokenizer, BertModel</span><br><span class="line">bert = BertModel.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br></pre></td></tr></table></figure><p>根据 <a href="https://arxiv.org/pdf/1810.04805.pdf" target="_blank" rel="noopener" title="Bert 论文">Bert 论文</a>,Bert base 模型的超参数有:transformers 层数为 12,隐藏层维度为 768,self-attention head 数量为 12。我们在实际模型中只需要隐藏层维度。现在我们搭建一个在 Bert 后面连接一个双层、双向 GRU 的模型。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">BertGRU</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, bert, hidden_dim, n_layers, bidirectional, dropout)</span>:</span></span><br><span class="line"> super().__init__()</span><br><span class="line"></span><br><span class="line"> self.bert = bert</span><br><span class="line"></span><br><span class="line"> embed_dim = bert.config.to_dict()[<span class="string">'hidden_size'</span>]</span><br><span class="line"></span><br><span class="line"> self.gru = nn.GRU(embed_dim, hidden_dim, num_layers=n_layers, bidirectional=bidirectional, </span><br><span class="line"> batch_first=<span class="literal">True</span>, dropout=<span class="number">0</span> <span class="keyword">if</span> n_layers < <span class="number">2</span> <span class="keyword">else</span> dropout)</span><br><span class="line"> 
</span><br><span class="line"> self.fc = nn.Linear(hidden_dim * <span class="number">2</span> <span class="keyword">if</span> bidirectional <span class="keyword">else</span> hidden_dim, <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> self.dropout = nn.Dropout(dropout)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, text)</span>:</span> <span class="comment"># text: [BATCH_SIZE, SEQ_LENGTH]</span></span><br><span class="line"> <span class="keyword">with</span> torch.no_grad():</span><br><span class="line"> embedded = self.bert(text)[<span class="number">0</span>] <span class="comment"># embedded: [BATCH_SIZE, SEQ_LENGTH, EMBED_DIM]</span></span><br><span class="line"></span><br><span class="line"> _, hidden = self.gru(embedded) <span class="comment"># hidden: [N_LAYERS * n_directions, BATCH_SIZE, HIDDEN_DIM]</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">if</span> self.gru.bidirectional:</span><br><span class="line"> hidden = self.dropout(torch.cat((hidden[<span class="number">-2</span>, :, :], hidden[<span class="number">-1</span>, :, :]), dim=<span class="number">1</span>))</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> hidden = self.dropout(hidden[<span class="number">-1</span>, :, :])</span><br><span class="line"></span><br><span class="line"> output = self.fc(hidden) <span class="comment"># output: [BATCH_SIZE, 1]</span></span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure><p>首先实例化这个模型。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">HIDDEN_DIM = <span class="number">768</span></span><br><span class="line">N_LAYERS = <span class="number">2</span></span><br><span class="line">BIDIRECTIONAL = <span class="literal">True</span></span><br><span class="line">DROPOUT = <span class="number">0.5</span></span><br><span class="line"></span><br><span class="line">model = BertGRU(bert, HIDDEN_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT)</span><br></pre></td></tr></table></figure><p>因为 Bert 是已经训练好的词向量,我们不希望它被训练,也不希望它的权重被更新,所以模型里有 <code>with torch.no_grad()</code> 代码块。另外我们也手动关闭 Bert 有关的权重更新:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> name, param <span class="keyword">in</span> model.named_parameters():</span><br><span class="line"> <span class="keyword">if</span> name.startswith(<span class="string">'bert'</span>):</span><br><span class="line"> param.requires_grad = <span class="literal">False</span></span><br></pre></td></tr></table></figure>
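<p>顺带一提,也可以在构建优化器时直接过滤掉这些被冻结的参数,让 Adam 只管理需要更新的部分。这是一个可选的小改动,效果与直接传入全部参数等价(<code>requires_grad=False</code> 的参数不会产生梯度,优化器也就不会更新它们):</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line"># 可选:只把需要梯度的参数交给优化器(Bert 部分已被冻结)</span><br><span class="line">from torch import optim</span><br><span class="line"></span><br><span class="line">trainable = filter(lambda p: p.requires_grad, model.parameters())</span><br><span class="line">optimizer = optim.Adam(trainable)</span><br></pre></td></tr></table></figure>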
class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> optim</span><br><span class="line"></span><br><span class="line">optimizer = optim.Adam(model.parameters())</span><br><span class="line"></span><br><span class="line">criterion = nn.BCEWithLogitsLoss()</span><br><span class="line"></span><br><span class="line">model = model.to(device)</span><br><span class="line">criterion = criterion.to(device)</span><br></pre></td></tr></table></figure><p>后面的训练和预测同以前的文章一样,不再赘述。训练 10 个 epoch 后的表现为:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">Epoch: <span class="number">10</span> | Epoch Time: <span class="number">38</span>m <span class="number">7</span>s</span><br><span class="line">Train Loss: <span class="number">0.094</span> | Train Acc: <span class="number">96.62</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.243</span> | Val. Acc: <span class="number">92.39</span>%</span><br></pre></td></tr></table></figure><p>有了 Bert 的加持,模型的性能提高了约 10%。</p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[NLP] 新手的第一个 NLP 项目:文本分类(3)</title>
<link href="2020/08/18/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E9%A1%B9%E7%9B%AE%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%883%EF%BC%89/"/>
<url>2020/08/18/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E9%A1%B9%E7%9B%AE%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%883%EF%BC%89/</url>
<content type="html"><![CDATA[<h1 id="前文回顾"><a href="#前文回顾" class="headerlink" title="前文回顾"></a>前文回顾</h1><p>在前两篇文章<a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485105&idx=1&sn=0fac0adff0dff8812d73a0e510d65e9a&chksm=cf06b78df8713e9b9a4024c7f09d2b031afce20e43dc696006be9bf9c34b6f4bf25f4e71acae&token=1910822964&lang=zh_CN#rd" target="_blank" rel="noopener">新手的第一个 NLP 任务:文本分类(1)</a>和<a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485108&idx=1&sn=62ec5a9782c5e7dbecd6cac8f7e0eb58&chksm=cf06b788f8713e9e19c8513a687e92e782720f3e851f1219e8f956e02c8f69c1bc5838606cc1&token=134245917&lang=zh_CN#rd" target="_blank" rel="noopener">新手的第一个 NLP 项目:文本分类(2)</a>中,我们读取了数据、对数据进行了预处理和封装,并搭建了一个 CNN 模型。本文中,我们将 CNN 模型换为 RNN 模型。</p><h1 id="数据的准备"><a href="#数据的准备" class="headerlink" title="数据的准备"></a>数据的准备</h1><p>同<a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485105&idx=1&sn=0fac0adff0dff8812d73a0e510d65e9a&chksm=cf06b78df8713e9b9a4024c7f09d2b031afce20e43dc696006be9bf9c34b6f4bf25f4e71acae&token=1910822964&lang=zh_CN#rd" target="_blank" rel="noopener">新手的第一个 NLP 任务:文本分类(1)</a>一样,不再赘述。</p><a id="more"></a><h1 id="基础-RNN-模型"><a href="#基础-RNN-模型" class="headerlink" title="基础 RNN 模型"></a>基础 RNN 模型</h1><p>有关 RNN 的知识可以参考我以前写的文章 <a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247485071&idx=1&sn=b9b570591e340c53b8d5cc161ef382cc&chksm=cf06b7b3f8713ea5f4b689a5c018ef21f5f1e42f32d2f4965acfec589d1260c21e9ceb90cee3&token=129126518&lang=zh_CN#rd" target="_blank" rel="noopener">PyTorch 折桂 11:CNN & RNN</a>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn, optim</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> functional <span class="keyword">as</span> F</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">RNN</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, vocab_size, embed_dim, hidden_dim)</span>:</span></span><br><span class="line"> super(RNN, self).__init__()</span><br><span class="line"> </span><br><span class="line"> self.embedding = nn.Embedding(vocab_size, embed_dim) <span class="comment"># (BATCH_SIZE, SEQ_LEN, EMBED_DIM)</span></span><br><span class="line"> self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=<span class="literal">True</span>)</span><br><span class="line"> self.fc = nn.Linear(hidden_dim, <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, 
x)</span>:</span></span><br><span class="line"> x = self.embedding(x)</span><br><span class="line"></span><br><span class="line"> output, hidden = self.rnn(x)</span><br><span class="line"> <span class="comment"># output: (BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM)</span></span><br><span class="line"> <span class="comment"># hidden: (1, BATCH_SIZE, HIDDEN_DIM)</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> self.fc(hidden.squeeze(<span class="number">0</span>))</span><br></pre></td></tr></table></figure><p>我们首先使用一层单向 RNN。RNN 网络生成两个张量:输出层与保存了历史信息的隐藏层。使用哪一个呢?这要具体问题具体分析。对于文本分类这种需要概括整句话信息的任务,一般使用保存了历史信息的隐藏层。</p><p>这里要注意隐藏层的维度:无论 <code>batch_first</code> 取何值,隐藏层的维度都是 <code>(num_layers * num_directions, BATCH_SIZE, HIDDEN_DIM)</code>;<code>batch_first</code> 只影响输入张量和 <code>output</code> 的维度,并不影响 <code>hidden</code>。</p><p>因为这是一个单层单向的 RNN,所以第 0 维为 1;在将隐藏层进行全连接处理以前,先去除无用的第 0 维。下面用一段小例子验证这一点。</p>
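<p>一个最小的维度验证示意(batch、序列长度等数字是随便取的,仅用于说明 <code>batch_first</code> 不影响 <code>hidden</code> 的维度):</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">import torch</span><br><span class="line">from torch import nn</span><br><span class="line"></span><br><span class="line">x = torch.randn(32, 100, 128)  # (BATCH_SIZE, SEQ_LEN, EMBED_DIM)</span><br><span class="line"></span><br><span class="line">rnn = nn.RNN(128, 256, batch_first=True)</span><br><span class="line">output, hidden = rnn(x)</span><br><span class="line">print(output.shape)   # torch.Size([32, 100, 256])</span><br><span class="line">print(hidden.shape)   # torch.Size([1, 32, 256])</span><br><span class="line"></span><br><span class="line">rnn2 = nn.RNN(128, 256, batch_first=False)</span><br><span class="line">output2, hidden2 = rnn2(x.transpose(0, 1))  # 输入变为 (SEQ_LEN, BATCH_SIZE, EMBED_DIM)</span><br><span class="line">print(output2.shape)  # torch.Size([100, 32, 256])</span><br><span class="line">print(hidden2.shape)  # torch.Size([1, 32, 256]),与上面一致</span><br></pre></td></tr></table></figure>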
<p>实例化 RNN 网络:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">EMBED_DIM = <span class="number">128</span></span><br><span class="line">HIDDEN_DIM = <span class="number">256</span></span><br><span class="line">rnn = RNN(len(vocab), EMBED_DIM, HIDDEN_DIM)</span><br></pre></td></tr></table></figure><p>损失函数、优化器、训练过程与前文一致,不再赘述。训练 10 个 epoch 后的结果如下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">Epoch: <span class="number">10</span> | Epoch Time: <span class="number">1</span>m <span class="number">8</span>s</span><br><span class="line">Train Loss: <span class="number">0.590</span> | Train Acc: <span class="number">68.58</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.682</span> | Val. Acc: <span class="number">61.32</span>%</span><br></pre></td></tr></table></figure><p>可以看到,模型不仅整体准确率不高,训练集与验证集之间还有明显差距。下面我们改进一下这个 RNN 模型。</p><h1 id="改进-RNN-模型"><a href="#改进-RNN-模型" class="headerlink" title="改进 RNN 模型"></a>改进 RNN 模型</h1><p>我们主要从以下三个方面进行改进:</p><ol><li>改进词嵌入;</li><li>增加模型的复杂度(使用两层双向 LSTM);</li><li>增加正则化。</li></ol><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">LSTM</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, vocab_size, embedding_dim, hidden_dim, n_layers, </span></span></span><br><span class="line"><span class="function"><span class="params"> bidirectional, dropout)</span>:</span></span><br><span class="line"> super(LSTM, self).__init__()</span><br><span class="line"> </span><br><span class="line"> self.embed = nn.Embedding(vocab_size, embedding_dim, padding_idx=<span class="number">0</span>)</span><br><span class="line"> </span><br><span class="line"> self.lstm = nn.LSTM(embedding_dim, </span><br><span class="line"> hidden_dim, </span><br><span class="line"> num_layers=n_layers, </span><br><span class="line"> bidirectional=bidirectional, </span><br><span class="line"> dropout=dropout,</span><br><span class="line"> batch_first=<span class="literal">True</span>)</span><br><span class="line"> </span><br><span class="line"> self.dropout = nn.Dropout(dropout)</span><br><span class="line"> self.num_directions = <span class="number">2</span> <span class="keyword">if</span> bidirectional <span class="keyword">else</span> <span class="number">1</span></span><br><span class="line"> self.fc = nn.Linear(hidden_dim * self.num_directions, <span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> embedded = self.dropout(self.embed(x)) <span class="comment"># (BATCH_SIZE, SEQ_LEN, EMBED_DIM)</span></span><br><span class="line"></span><br><span class="line"> output, (hidden, cell) = self.lstm(embedded)</span><br><span class="line"> <span class="comment"># output: (BATCH_SIZE, SEQ_LENGTH, HIDDEN_DIM)</span></span><br><span class="line"> <span class="comment"># hidden: (n_layers * num_directions, BATCH_SIZE, HIDDEN_DIM)</span></span><br><span class="line"> <span class="comment"># cell: (n_layers * num_directions, BATCH_SIZE, HIDDEN_DIM)</span></span><br><span class="line"> </span><br><span class="line"> hidden = self.dropout(torch.cat((hidden[<span class="number">-2</span>, :, :], hidden[<span class="number">-1</span>, :, :]), dim=<span class="number">1</span>))</span><br><span class="line"> <span class="comment"># hidden: (BATCH_SIZE, HIDDEN_DIM * 2)</span></span><br><span class="line"> <span class="keyword">return</span> self.fc(hidden)</span><br></pre></td></tr></table></figure><p>首先,填充 <code><PAD></code> 应该恒为 0,所以我们在词嵌入层中加入 <code>padding_idx=0</code> 条件。这里 <code>padding_idx</code> 为 0 是因为我们在数据准备过程中将填充占位设为 0。</p>
<p>其次,将 RNN 层变成 LSTM 层。LSTM 模型的输出有三个,<code>output, (hidden, cell)</code>,隐藏层与细胞状态在一个元组内。与 RNN 一样,无论 <code>batch_first</code> 取何值,隐藏层与细胞状态的维度都是 <code>(num_layers * num_directions, BATCH_SIZE, HIDDEN_DIM)</code>,<code>batch_first</code> 只影响输入和 <code>output</code> 的维度。当方向为双向且层数多于 1 时,隐藏层与细胞状态的堆叠层次为:$[第一层正向,第一层反向,…,最后一层正向,最后一层反向]$。这里使用了两层双向 LSTM。我们需要最后一层的正向与反向隐藏层,并把它们拼接在一起。</p><p>最后,还加入了 dropout 正则化。LSTM 内部的 dropout 可以使用 <code>dropout</code> 声明,LSTM 与全连接层之间的 dropout 可以使用 <code>nn.Dropout</code> 层。</p><p>实例化这个 LSTM 模型。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">EMBED_DIM = <span class="number">128</span></span><br><span class="line">HIDDEN_DIM = <span class="number">256</span></span><br><span class="line">N_LAYERS = <span class="number">2</span></span><br><span class="line">BIDIRECTIONAL = <span class="literal">True</span></span><br><span class="line">DROPOUT = <span class="number">0.5</span></span><br><span class="line"></span><br><span class="line">lstm = LSTM(len(vocab), EMBED_DIM, HIDDEN_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT)</span><br></pre></td></tr></table></figure><p>损失函数、优化器、训练过程与前面相同。最终的训练效果为:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">Epoch: <span class="number">10</span> | Epoch Time: <span class="number">7</span>m <span class="number">34</span>s</span><br><span class="line">Train Loss: <span class="number">0.303</span> | Train Acc: <span class="number">87.30</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.412</span> | Val. Acc: <span class="number">83.43</span>%</span><br></pre></td></tr></table></figure><p>比前面的 CNN 效果稍好。下文中我们将使用 SOTA 的预训练模型 - BERT。</p><p>本文的代码可以在 <a href="https://github.com/vincent507cpu/nlp\_project/blob/master/text%20classification/02%20RNN.ipynb" target="_blank" rel="noopener">https://github.com/vincent507cpu/nlp\_project/blob/master/text%20classification/02%20RNN.ipynb</a> 查看。</p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[NLP] 新手的第一个 NLP 项目:文本分类(2)</title>
<link href="2020/08/07/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E9%A1%B9%E7%9B%AE%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%882%EF%BC%89/"/>
<url>2020/08/07/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E9%A1%B9%E7%9B%AE%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%882%EF%BC%89/</url>
<content type="html"><![CDATA[<p>现在数据已经准备就绪,可以构建模型了。</p><p>本文的模型参考了论文 <a href="https://arxiv.org/pdf/1408.5882.pdf" target="_blank" rel="noopener" title="Convolutional Neural Networks for Sentence Classification">《Convolutional Neural Networks for Sentence Classification》</a>,原文代码<a href="https://github.com/dennybritz/cnn-text-classification-tf" target="_blank" rel="noopener" title="论文代码">在此</a>。</p><p>论文里使用了两个词嵌入:随模型进行训练的词嵌入和 Google 预训练好的 Word2Vec 词嵌入。本文里为了直观,没有采用预训练的词嵌入。</p><h1 id="构建模型"><a href="#构建模型" class="headerlink" title="构建模型"></a>构建模型</h1><p>论文里使用了三个卷积核分别为 3、4、5 的二维卷积层,拼接后经过一个范围为 4 的池化层。最后经过一个全连接层,经过 sigmoid 函数处理后输出。</p><a id="more"></a><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn</span><br><span class="line"><span class="keyword">from</span> torch.nn <span class="keyword">import</span> functional <span class="keyword">as</span> F</span><br><span class="line"></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">CNN</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, vocab_size, embed_size, dropout, batch_size)</span>:</span></span><br><span class="line"> super(CNN, self).__init__()</span><br><span class="line"> self.batch_size = batch_size</span><br><span class="line"> </span><br><span class="line"> self.embedding = nn.Embedding(vocab_size, embed_size) <span class="comment"># (BATCH_SIZE, SEQ_LEN, embed_size)</span></span><br><span class="line"> </span><br><span class="line"> self.conv1 = nn.Conv2d(<span class="number">1</span>, <span class="number">1</span>, <span class="number">3</span>)</span><br><span class="line"> self.conv2 = nn.Conv2d(<span class="number">1</span>, <span class="number">1</span>, <span class="number">4</span>)</span><br><span class="line"> self.conv3 = nn.Conv2d(<span class="number">1</span>, <span class="number">1</span>, <span class="number">5</span>)</span><br><span class="line"></span><br><span class="line"> self.dropout = nn.Dropout(dropout)</span><br><span class="line"> </span><br><span class="line"> self.fc = nn.Linear(<span class="number">2232</span>, <span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span 
class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> x = self.embedding(x)</span><br><span class="line"> x.unsqueeze_(<span class="number">1</span>) <span class="comment"># (BATCH_SIZE, 1, SEQ_LEN, embed_size)</span></span><br><span class="line"> output1 = self.conv1(x)</span><br><span class="line"> output1 = F.max_pool2d(F.relu(output1), <span class="number">4</span>)</span><br><span class="line"> </span><br><span class="line"> output2 = self.conv2(x)</span><br><span class="line"> output2 = F.max_pool2d(F.relu(output2), <span class="number">4</span>)</span><br><span class="line"> </span><br><span class="line"> output3 = self.conv3(x)</span><br><span class="line"> output3 = F.max_pool2d(F.relu(output3), <span class="number">4</span>)</span><br><span class="line"> output = torch.cat([output1, output2, output3], axis=<span class="number">1</span>)</span><br><span class="line"> output = self.dropout(output)</span><br><span class="line"></span><br><span class="line"> <span class="keyword">return</span> self.fc(output.view(self.batch_size, <span class="number">-1</span>))</span><br></pre></td></tr></table></figure><h5 id="注意:因为经过词嵌入的张量维度为-BATCH-SIZE-SEQ-LEN-embed-size-,而-nn-Conv2d-的输入张量的维度要求为-BATCH-SIZE-CHANNEL-NONE-NONE-,所以我们需要使用-x-unsqueeze-1-为张量添加一个维度。"><a href="#注意:因为经过词嵌入的张量维度为-BATCH-SIZE-SEQ-LEN-embed-size-,而-nn-Conv2d-的输入张量的维度要求为-BATCH-SIZE-CHANNEL-NONE-NONE-,所以我们需要使用-x-unsqueeze-1-为张量添加一个维度。" class="headerlink" title="注意:因为经过词嵌入的张量维度为 (BATCH_SIZE, SEQ_LEN, embed_size),而 nn.Conv2d 的输入张量的维度要求为 (BATCH_SIZE, CHANNEL, NONE, NONE),所以我们需要使用 x.unsqueeze_(1) 为张量添加一个维度。"></a>注意:因为经过词嵌入的张量维度为 <code>(BATCH_SIZE, SEQ_LEN, embed_size)</code>,而 <code>nn.Conv2d</code> 的输入张量的维度要求为 <code>(BATCH_SIZE, CHANNEL, NONE, NONE)</code>,所以我们需要使用 <code>x.unsqueeze_(1)</code> 为张量添加一个维度。</h5><p>我们使用 Adam 为优化器,<code>nn.BCEWithLogitsLoss()</code> 为损失函数。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> optim</span><br><span class="line"></span><br><span class="line">optimizer = optim.Adam(model.parameters())</span><br><span class="line">criterion = nn.BCEWithLogitsLoss()</span><br></pre></td></tr></table></figure><h5 id="注意:nn-BCEWithLogitsLoss-是先进行了-sigmoid-运算后再求交叉熵的损失函数,无需额外的-sigmoid-运算。"><a href="#注意:nn-BCEWithLogitsLoss-是先进行了-sigmoid-运算后再求交叉熵的损失函数,无需额外的-sigmoid-运算。" class="headerlink" title="注意:nn.BCEWithLogitsLoss() 是先进行了 sigmoid 运算后再求交叉熵的损失函数,无需额外的 sigmoid 运算。"></a>注意:<code>nn.BCEWithLogitsLoss()</code> 是先进行了 sigmoid 运算后再求交叉熵的损失函数,无需额外的 sigmoid 运算。</h5><p>然后我们再定义一个求准确率的函数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">binary_accuracy</span><span class="params">(preds, y)</span>:</span></span><br><span class="line"> rounded_preds = torch.round(torch.sigmoid(preds))</span><br><span class="line"> correct = (rounded_preds == y).float()</span><br><span class="line"> acc = correct.sum() / len(correct)</span><br><span class="line"> <span class="keyword">return</span> 
<p>紧接着我们开始定义训练和验证的函数:</p>
<pre><code class="py"># 训练函数
def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.train() # 训练模式

    for text, label in iterator:
        optimizer.zero_grad()
        preds = model(text)
        loss = criterion(preds.squeeze(), label.float())
        acc = binary_accuracy(preds.squeeze(), label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

# 验证函数
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0

    model.eval() # 验证模式

    with torch.no_grad():
        for text, label in iterator:
            preds = model(text)
            loss = criterion(preds.squeeze(), label.float())
            acc = binary_accuracy(preds.squeeze(), label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)
</code></pre>
<p>可以看到,训练函数与验证函数大同小异,主要区别在于:</p><ol><li>训练模式下更新权重,验证模式下(<code>torch.no_grad()</code>)不更新权重;</li><li>验证函数没有优化器。</li></ol><h5>注意:在计算损失时,真实标签也要转换成 float 类型,否则 <code>nn.BCEWithLogitsLoss()</code> 会报类型错误。</h5><p>下面就可以构建真正的训练、评估循环了:</p>
<pre><code class="py">import time

def epoch_time(start_time, end_time): # 计算每一轮花费的时间
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - elapsed_mins * 60)
    return elapsed_mins, elapsed_secs

N_EPOCHS = 10
best_test_loss = float('inf')

for epoch in range(N_EPOCHS):
    start_time = time.time()

    train_loss, train_acc = train(model, train_iter, optimizer, criterion)
    test_loss, test_acc = evaluate(model, test_iter, criterion)

    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if test_loss < best_test_loss:
        best_test_loss = test_loss
        torch.save(model.state_dict(), 'model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {test_loss:.3f} | Val. Acc: {test_acc*100:.2f}%')
</code></pre>
<p>我们共训练 10 轮,如果验证集的损失低于此前的最低损失,就保存当前模型。最佳结果为:</p>
<pre><code class="plain">Epoch: 10 | Epoch Time: 4m 5s
Train Loss: 0.171 | Train Acc: 93.41%
 Val. Loss: 0.426 | Val. Acc: 82.85%
</code></pre>
class="number">1</span>:<span class="number">02</span>}</span> | Epoch Time: <span class="subst">{epoch_mins}</span>m <span class="subst">{epoch_secs}</span>s'</span>)</span><br><span class="line"> print(<span class="string">f'\tTrain Loss: <span class="subst">{train_loss:<span class="number">.3</span>f}</span> | Train Acc: <span class="subst">{train_acc*<span class="number">100</span>:<span class="number">.2</span>f}</span>%'</span>)</span><br><span class="line"> print(<span class="string">f'\t Val. Loss: <span class="subst">{test_loss:<span class="number">.3</span>f}</span> | Val. Acc: <span class="subst">{test_acc*<span class="number">100</span>:<span class="number">.2</span>f}</span>%'</span>)</span><br></pre></td></tr></table></figure>我们进行 10 轮训练,如果验证集的准确率大于最大准确率,则保存模型。最佳结果为:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">Epoch: <span class="number">10</span> | Epoch Time: <span class="number">4</span>m <span class="number">5</span>s</span><br><span class="line">Train Loss: <span class="number">0.171</span> | Train Acc: <span class="number">93.41</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.426</span> | Val. Acc: <span class="number">82.85</span>%</span><br></pre></td></tr></table></figure>这个结果马马虎虎,希望在后面将模型改进后,模型的表现会更好。</li></ol><p>可以在 <a href="https://github.com/vincent507cpu/nlp\_project/blob/master/text%20classification/01%20CNN.ipynb" target="_blank" rel="noopener">https://github.com/vincent507cpu/nlp\_project/blob/master/text%20classification/01%20CNN.ipynb</a> 查看全部代码。</p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[NLP] 新手的第一个 NLP 任务:文本分类(1)</title>
<link href="2020/08/01/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E4%BB%BB%E5%8A%A1%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%881%EF%BC%89/"/>
<url>2020/08/01/NLP-%E6%96%B0%E6%89%8B%E7%9A%84%E7%AC%AC%E4%B8%80%E4%B8%AA-NLP-%E4%BB%BB%E5%8A%A1%EF%BC%9A%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%EF%BC%881%EF%BC%89/</url>
<content type="html"><![CDATA[<p>从终端任务来说,NLP 任务有文本分类、文本生成、翻译、文本摘要等等,其中文本分类是一个比较基础的任务。所以让我们从文本分类开始练习,从最简单的模型开始做起,然后尽量一步步提高它的性能。</p><p>文本分类有主题分类和感情分类两种。其中感情分类又比主题分类更加简单一点,因为很多感情分类是二分类任务(主题分类其实也可以,但是一般很少只分两个主题),所以我们将使用 IMDB 电影评论数据集进行一个感情分类任务。</p><a id="more"></a><h1 id="NLP-的-pipeline"><a href="#NLP-的-pipeline" class="headerlink" title="NLP 的 pipeline"></a>NLP 的 pipeline</h1><p>简单来说,NLP 的 pipeline 的主要步骤为:</p><ol><li>载入数据;</li><li>数据探索与分析(EDA);</li><li>数据预处理;</li><li>数据的封装;</li><li>构建模型;</li><li>训练模型;</li><li>评估模型;</li><li>(可选)模型的推断。</li></ol><p>本文主要关注第 1、3、4 步。数据分析这里就略过了,因为 1)这个数据集是一个很经典的数据集,网上已经有无数人做了 EDA;2)我对 <code>pandas</code> 和 <code>matplotlib</code> 还不熟。因为我们现在要构建一个基线模型,采用的方法也比较原始,后面会介绍更高效、简便的方式。</p><h1 id="准备工作"><a href="#准备工作" class="headerlink" title="准备工作"></a>准备工作</h1><p>首先安装、升级所需的库(代码在 Jupyter Notebook 里运行,在 shell 里运行需要把每个命令前面的 <code>!</code> 去掉):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">!pip install -U tqdm <span class="comment"># 4.48.0</span></span><br><span class="line">!pip install -U nltk <span class="comment"># 3.5</span></span><br><span class="line">!pip install -U spacy <span class="comment"># 2.3.2</span></span><br><span class="line">!pip install -U numpy <span class="comment"># 1.19.1</span></span><br><span class="line">!pip install -U pandas <span class="comment"># 1.1.0</span></span><br><span class="line">!pip install -U sklearn <span class="comment"># 0.23</span></span><br><span class="line">!pip install -U torch <span class="comment"># 1.6</span></span><br><span class="line">!pip install -U torchtext <span class="comment"># 0.7.0</span></span><br></pre></td></tr></table></figure><p>后续文章中默认使用以上最新的库。然后下载 <code>spacy</code> 和 <code>nltk</code> 的数据:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">!python -m spacy download en_core_web_md</span><br><span class="line"></span><br><span class="line"><span class="keyword">from</span> nltk.stem <span class="keyword">import</span> WordNetLemmatizer</span><br><span class="line">nltk.download()</span><br></pre></td></tr></table></figure><h1 id="载入数据"><a href="#载入数据" class="headerlink" title="载入数据"></a>载入数据</h1><p>我们首先使用 <code>pandas</code> 读取 <code>csv</code> 文件。IMDB 电影评论一共有 50000 条,分为 <code>positive</code> 和 <code>negative</code> 两种。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> pandas <span class="keyword">as</span> pd</span><br><span class="line"></span><br><span class="line">data = pd.read_csv(<span class="string">'.../datasets/IMDB Dataset.csv'</span>)</span><br></pre></td></tr></table></figure><h1 id="数据预处理"><a href="#数据预处理" class="headerlink" title="数据预处理"></a>数据预处理</h1><p>对于 NLP 任务来说,数据即文本。文本预处理任务一般有:</p><ol><li>文本清洗(去除乱码、停用词等);</li><li>分词;</li><li>(仅限英文)将词语进行还原;</li><li>文本的截取与补全;</li><li>构建词汇表;</li><li>创建一个将 token 转换为 id 的映射并将文本转换为 id(有时候还需要创建一个将 id 转换为token 的映射)。</li></ol><p><code>nltk</code> 和 
<code>spacy</code> 是处理英文 NLP 任务的两个常用的库。本来我习惯使用 <code>nltk</code> 进行分词,然而发现 <code>nltk</code> 的效果没有 <code>spacy</code> 好。所以我这次使用 <code>spacy</code> 进行分词,使用 <code>nltk</code> 将词语还原成原型。</p><p>首先做一些准备工作:</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">from nltk.stem import WordNetLemmatizer</span><br><span class="line">lemmatizer = WordNetLemmatizer() # 初始化 lemmatizer</span><br><span class="line"></span><br><span class="line">import spacy</span><br><span class="line">nlp = spacy.load('en_core_web_md') # 初始化语言处理引擎,用于分词</span><br></pre></td></tr></table></figure><p>因为深度学习模型只能处理数字,我们需要将文本转换为数字。我把所有的事情放在一起做了:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> tqdm <span class="keyword">import</span> tqdm</span><br><span class="line"><span class="keyword">import</span> re</span><br><span class="line"></span><br><span class="line">processed_review = []</span><br><span class="line">sentiment = []</span><br><span class="line"></span><br><span class="line">word2id = {<span class="string">'<PAD>'</span>:<span class="number">0</span>} <span class="comment"># token 到 id 的映射</span></span><br><span class="line"><span class="comment"># id2word = {0:'<PAD>'} # id 到 token 的映射,这个任务用不到</span></span><br><span class="line">vocab = set([<span class="string">'<PAD>'</span>]) <span class="comment"># 词汇表</span></span><br><span class="line">count = <span class="number">1</span></span><br><span class="line">SEQ_LEN = <span class="number">100</span> <span class="comment"># 每条文本的固定长度</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> i <span class="keyword">in</span> tqdm(range(len(data))): <span class="comment"># tqdm 显示进度</span></span><br><span class="line"> text = data.review[i].lower() <span class="comment"># 转换为小写</span></span><br><span class="line"> text = re.sub(<span class="string">'<.+?>'</span>, <span class="string">''</span>, text) <span class="comment"># 去掉 HTML 文本</span></span><br><span class="line"> text = re.sub(<span class="string">'[<>]'</span>, <span class="string">''</span>, 
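<p>先用一句话感受一下分词加词形还原的效果(输出为示意;注意 <code>lemmatize</code> 默认按名词处理,动词形态不会被还原):</p>
<pre><code class="py">tokens = [lemmatizer.lemmatize(t.text) for t in nlp.tokenizer('The cats are running')]
print(tokens) # ['The', 'cat', 'are', 'running']
</code></pre>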
text) <span class="comment"># 去掉 HTML 文本</span></span><br><span class="line"> text = [lemmatizer.lemmatize(token.text) <span class="keyword">for</span> token <span class="keyword">in</span> nlp.tokenizer(text)][:SEQ_LEN] <span class="comment"># 先分词,再还原,最后截取</span></span><br><span class="line"></span><br><span class="line"> tmp = [<span class="number">0</span>] * (SEQ_LEN - len(text)) <span class="keyword">if</span> len(text) < SEQ_LEN <span class="keyword">else</span> [] <span class="comment"># 用 0 补全短文本</span></span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 构建词汇表以及映射</span></span><br><span class="line"> <span class="keyword">for</span> word <span class="keyword">in</span> text:</span><br><span class="line"> <span class="keyword">if</span> word <span class="keyword">not</span> <span class="keyword">in</span> vocab:</span><br><span class="line"> vocab.add(word)</span><br><span class="line"> word2id[word] = count</span><br><span class="line"> tmp.append(count)</span><br><span class="line"> count += <span class="number">1</span></span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> tmp.append(word2id[word])</span><br><span class="line"></span><br><span class="line"> processed_review.append(tmp)</span><br><span class="line"> </span><br><span class="line"> <span class="comment"># 将 positive 转换 为 1,将 negative 转换为 0</span></span><br><span class="line"> <span class="keyword">if</span> data.sentiment[i] == <span class="string">'positive'</span>:</span><br><span class="line"> sentiment.append(<span class="number">1</span>)</span><br><span class="line"> <span class="keyword">elif</span> data.sentiment[i] == <span class="string">'negative'</span>:</span><br><span class="line"> sentiment.append(<span class="number">0</span>)</span><br></pre></td></tr></table></figure><h1 id="数据封装"><a href="#数据封装" class="headerlink" title="数据封装"></a>数据封装</h1><p>现在数据和标签都变成了数字,然后是划分训练集和测试集(我们暂时不用验证集)。这里使用 <code>sklearn</code> 里的函数实现,产生 40000 条训练集和 10000 条测试集。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> sklearn.model_selection <span class="keyword">import</span> train_test_split</span><br><span class="line"></span><br><span class="line">X_train, X_test, y_train, y_test = train_test_split(processed_review, sentiment, train_size=<span class="number">0.8</span>, random_state=<span class="number">1988</span>)</span><br></pre></td></tr></table></figure><p>在构建模型之前的最后一步是封装数据,以便以 batch 的数量将数据送进网络。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torch.utils.data <span class="keyword">import</span> TensorDataset, DataLoader</span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line">BATCH_SIZE = <span class="number">64</span></span><br><span class="line"></span><br><span class="line">train_ds = TensorDataset(torch.as_tensor(X_train), torch.as_tensor(y_train))</span><br><span 
class="line">test_ds = TensorDataset(torch.as_tensor(X_test), torch.as_tensor(y_test))</span><br><span class="line"></span><br><span class="line">train_iter = DataLoader(train_ds, batch_size=BATCH_SIZE, drop_last=<span class="literal">True</span>) <span class="comment"># (BATCH_SIZE, SEQ_LEN)</span></span><br><span class="line">test_iter = DataLoader(test_ds, batch_size=BATCH_SIZE, drop_last=<span class="literal">True</span>) <span class="comment"># (BATCH_SIZE, )</span></span><br></pre></td></tr></table></figure><p>首先使用 <code>TensorDataset</code> 将训练集和测试集转换成 PyTorch 可以识别的格式,然后使用 <code>DataLoader</code> 将数据集进行封装,生成一个以 <code>BATCH_SIZE</code> 为读取批量的生成器。</p><p>下一篇文章将进行建模和训练。可以在 <a href="https://github.com/vincent507cpu/nlp_project/blob/master/text%20classification/01%20CNN.ipynb" target="_blank" rel="noopener">https://github.com/vincent507cpu/nlp_project/blob/master/text%20classification/01%20CNN.ipynb</a> 查看全部代码。</p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>DL-PyTorch-折桂-18:使用-TorchText-和-transformers-进行情感分类(2)</title>
<link href="2020/08/01/DL-PyTorch-%E6%8A%98%E6%A1%82-18%EF%BC%9A%E4%BD%BF%E7%94%A8-TorchText-%E5%92%8C-transformers-%E8%BF%9B%E8%A1%8C%E6%83%85%E6%84%9F%E5%88%86%E7%B1%BB-2/"/>
<url>2020/08/01/DL-PyTorch-%E6%8A%98%E6%A1%82-18%EF%BC%9A%E4%BD%BF%E7%94%A8-TorchText-%E5%92%8C-transformers-%E8%BF%9B%E8%A1%8C%E6%83%85%E6%84%9F%E5%88%86%E7%B1%BB-2/</url>
<content type="html"><![CDATA[<p><a href="https://vincent507cpu.github.io/2020/06/10/DL-PyTorch-折桂-17:使用-TorchText-和-transformers-进行情感分类/">接上文</a></p><h1 id="9-搭建模型"><a href="#9-搭建模型" class="headerlink" title="9. 搭建模型"></a>9. 搭建模型</h1><p>首先是载入预训练模型。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> transformers <span class="keyword">import</span> BertTokenizer, BertModel</span><br><span class="line"></span><br><span class="line">bert = BertModel.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br></pre></td></tr></table></figure><p>我们使用 Bert 预训练词向量与 GRU 组成模型,然后接一个全连接层。我们需要使用 <code>with torch.no_grad()</code> 避免预训练词向量发生变化。</p><a id="more"></a><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br><span class="line">32</span><br><span class="line">33</span><br><span class="line">34</span><br><span class="line">35</span><br><span class="line">36</span><br><span class="line">37</span><br><span class="line">38</span><br><span class="line">39</span><br><span class="line">40</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">BERTGRUSentiment</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self,</span></span></span><br><span class="line"><span class="function"><span class="params"> bert,</span></span></span><br><span class="line"><span class="function"><span class="params"> hidden_dim,</span></span></span><br><span class="line"><span class="function"><span class="params"> output_dim,</span></span></span><br><span class="line"><span class="function"><span class="params"> n_layers,</span></span></span><br><span class="line"><span class="function"><span class="params"> bidirectional,</span></span></span><br><span class="line"><span class="function"><span class="params"> dropout)</span>:</span></span><br><span class="line"> </span><br><span class="line"> super().__init__()</span><br><span class="line"> self.bert = bert</span><br><span class="line"> embedding_dim = bert.config.to_dict()[<span class="string">'hidden_size'</span>]</span><br><span class="line"> self.rnn = nn.GRU(embedding_dim,</span><br><span class="line"> hidden_dim,</span><br><span class="line"> num_layers = 
n_layers,</span><br><span class="line"> bidirectional = bidirectional,</span><br><span class="line"> batch_first = <span class="literal">True</span>,</span><br><span class="line"> dropout = <span class="number">0</span> <span class="keyword">if</span> n_layers < <span class="number">2</span> <span class="keyword">else</span> dropout)</span><br><span class="line"> </span><br><span class="line"> self.out = nn.Linear(hidden_dim * <span class="number">2</span> <span class="keyword">if</span> bidirectional <span class="keyword">else</span> hidden_dim, output_dim)</span><br><span class="line"> self.dropout = nn.Dropout(<span class="number">0.5</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, text)</span>:</span></span><br><span class="line"> <span class="comment">#text = [batch size, sent len]</span></span><br><span class="line"> <span class="keyword">with</span> torch.no_grad():</span><br><span class="line"> embedded = self.bert(text)[<span class="number">0</span>]</span><br><span class="line"> <span class="comment">#embedded = [batch size, sent len, emb dim]</span></span><br><span class="line"> _, hidden = self.rnn(embedded)</span><br><span class="line"> </span><br><span class="line"> <span class="comment">#hidden = [n layers * n directions, batch size, emb dim]</span></span><br><span class="line"> <span class="keyword">if</span> self.rnn.bidirectional:</span><br><span class="line"> hidden = self.dropout(torch.cat((hidden[<span class="number">-2</span>,:,:], hidden[<span class="number">-1</span>,:,:]), dim = <span class="number">1</span>))</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> hidden = self.dropout(hidden[<span class="number">-1</span>,:,:])</span><br><span class="line"> </span><br><span class="line"> <span class="comment">#hidden = [batch size, hid dim]</span></span><br><span class="line"> output = self.out(hidden)</span><br><span class="line"> </span><br><span class="line"> <span class="comment">#output = [batch size, out dim]</span></span><br><span class="line"> <span class="keyword">return</span> output</span><br></pre></td></tr></table></figure><p>接下来我们使用标准超参数将模型实例化。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">HIDDEN_DIM = <span class="number">256</span></span><br><span class="line">OUTPUT_DIM = <span class="number">1</span></span><br><span class="line">N_LAYERS = <span class="number">2</span></span><br><span class="line">BIDIRECTIONAL = <span class="literal">True</span></span><br><span class="line">DROPOUT = <span class="number">0.25</span></span><br><span class="line"></span><br><span class="line">model = BERTGRUSentiment(bert,</span><br><span class="line"> HIDDEN_DIM,</span><br><span class="line"> OUTPUT_DIM,</span><br><span class="line"> N_LAYERS,</span><br><span class="line"> BIDIRECTIONAL,</span><br><span class="line"> DROPOUT)</span><br></pre></td></tr></table></figure><p>由于 transformer 的训练量实在比较大,我们设置不更新 bert 的权重:</p><figure class="highlight 
py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> name, param <span class="keyword">in</span> model.named_parameters(): </span><br><span class="line"> <span class="keyword">if</span> name.startswith(<span class="string">'bert'</span>):</span><br><span class="line"> param.requires_grad = <span class="literal">False</span></span><br></pre></td></tr></table></figure><h1 id="10-训练模型"><a href="#10-训练模型" class="headerlink" title="10. 训练模型"></a>10. 训练模型</h1><p>优化器使用 Adam,损失函数使用 <code>nn.BCEWithLogitsLoss()</code>。除此以外,我们再定义一个评价准确率的函数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">binary_accuracy</span><span class="params">(preds, y)</span>:</span></span><br><span class="line"> <span class="string">"""</span></span><br><span class="line"><span class="string"> Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8</span></span><br><span class="line"><span class="string"> """</span></span><br><span class="line"> <span class="comment">#round predictions to the closest integer</span></span><br><span class="line"> rounded_preds = torch.round(torch.sigmoid(preds))</span><br><span class="line"> correct = (rounded_preds == y).float() <span class="comment">#convert into float for division </span></span><br><span class="line"> acc = correct.sum() / len(correct)</span><br><span class="line"> <span class="keyword">return</span> acc</span><br></pre></td></tr></table></figure><p>训练函数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">train</span><span class="params">(model, iterator, optimizer, criterion)</span>:</span></span><br><span class="line"> epoch_loss = <span class="number">0</span></span><br><span class="line"> epoch_acc = <span class="number">0</span></span><br><span class="line"> </span><br><span class="line"> model.train()</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">for</span> batch <span class="keyword">in</span> iterator:</span><br><span class="line"> optimizer.zero_grad()</span><br><span class="line"> predictions = model(batch.text).squeeze(<span class="number">1</span>)</span><br><span class="line"> loss = criterion(predictions, batch.label)</span><br><span class="line"> acc = binary_accuracy(predictions, batch.label)</span><br><span class="line"> loss.backward()</span><br><span 
class="line"> optimizer.step()</span><br><span class="line"> epoch_loss += loss.item()</span><br><span class="line"> epoch_acc += acc.item()</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> epoch_loss / len(iterator), epoch_acc / len(iterator)</span><br></pre></td></tr></table></figure><p>验证函数与训练函数类似,区别在于:</p><ol><li>不更新权重;</li><li>没有优化器。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">evaluate</span><span class="params">(model, iterator, criterion)</span>:</span></span><br><span class="line"> epoch_loss = <span class="number">0</span></span><br><span class="line"> epoch_acc = <span class="number">0</span></span><br><span class="line"> </span><br><span class="line"> model.eval()</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">with</span> torch.no_grad():</span><br><span class="line"> <span class="keyword">for</span> batch <span class="keyword">in</span> iterator:</span><br><span class="line"> predictions = model(batch.text).squeeze(<span class="number">1</span>)</span><br><span class="line"> loss = criterion(predictions, batch.label)</span><br><span class="line"> acc = binary_accuracy(predictions, batch.label)</span><br><span class="line"> epoch_loss += loss.item()</span><br><span class="line"> epoch_acc += acc.item()</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> epoch_loss / len(iterator), epoch_acc / len(iterator)</span><br></pre></td></tr></table></figure>训练函数与验证函数都写好了以后就可以进行真正的训练了。在每一轮里,我们首先更新权重,然后用新的权重去验证。如果验证的损失小于之前的最小值,我们保存当前的模型。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">N_EPOCHS = <span class="number">10</span></span><br><span class="line">best_valid_loss = float(<span class="string">'inf'</span>)</span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> epoch <span class="keyword">in</span> range(N_EPOCHS):</span><br><span class="line"> train_loss, train_acc = train(model, train_iterator, optimizer, criterion)</span><br><span class="line"> valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span> valid_loss < best_valid_loss:</span><br><span class="line"> best_valid_loss = valid_loss</span><br><span class="line"> torch.save(model.state_dict(), <span class="string">'tut6-model.pt'</span>)</span><br><span 
class="line"> </span><br><span class="line"> print(<span class="string">f'Epoch: <span class="subst">{epoch+<span class="number">1</span>:<span class="number">02</span>}</span></span></span><br><span class="line"><span class="string"> print(f'</span>\tTrain Loss: {train_loss:<span class="number">.3</span>f} | Train Acc: {train_acc*<span class="number">100</span>:<span class="number">.2</span>f}%<span class="string">')</span></span><br><span class="line"><span class="string"> print(f'</span>\t Val. Loss: {valid_loss:<span class="number">.3</span>f} | Val. Acc: {valid_acc*<span class="number">100</span>:<span class="number">.2</span>f}%<span class="string">')</span></span><br></pre></td></tr></table></figure>训练过程如下:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><span class="line">Epoch: <span class="number">01</span> | Epoch Time: <span class="number">7</span>m <span class="number">5</span>s</span><br><span class="line">Train Loss: <span class="number">0.468</span> | Train Acc: <span class="number">76.80</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.266</span> | Val. Acc: <span class="number">89.47</span>%</span><br><span class="line">Epoch: <span class="number">02</span> | Epoch Time: <span class="number">7</span>m <span class="number">4</span>s</span><br><span class="line">Train Loss: <span class="number">0.280</span> | Train Acc: <span class="number">88.42</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.244</span> | Val. Acc: <span class="number">90.20</span>%</span><br><span class="line">Epoch: <span class="number">03</span> | Epoch Time: <span class="number">7</span>m <span class="number">4</span>s</span><br><span class="line">Train Loss: <span class="number">0.239</span> | Train Acc: <span class="number">90.48</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.220</span> | Val. Acc: <span class="number">91.07</span>%</span><br><span class="line">Epoch: <span class="number">04</span> | Epoch Time: <span class="number">7</span>m <span class="number">4</span>s</span><br><span class="line">Train Loss: <span class="number">0.211</span> | Train Acc: <span class="number">91.66</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.236</span> | Val. 
Acc: <span class="number">90.85</span>%</span><br><span class="line">Epoch: <span class="number">05</span> | Epoch Time: <span class="number">7</span>m <span class="number">5</span>s</span><br><span class="line">Train Loss: <span class="number">0.187</span> | Train Acc: <span class="number">92.91</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.222</span> | Val. Acc: <span class="number">91.12</span>%</span><br><span class="line">Epoch: <span class="number">06</span> | Epoch Time: <span class="number">7</span>m <span class="number">5</span>s</span><br><span class="line">Train Loss: <span class="number">0.164</span> | Train Acc: <span class="number">93.71</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.251</span> | Val. Acc: <span class="number">91.29</span>%</span><br><span class="line">Epoch: <span class="number">07</span> | Epoch Time: <span class="number">7</span>m <span class="number">4</span>s</span><br><span class="line">Train Loss: <span class="number">0.137</span> | Train Acc: <span class="number">94.94</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.231</span> | Val. Acc: <span class="number">90.73</span>%</span><br><span class="line">Epoch: <span class="number">08</span> | Epoch Time: <span class="number">7</span>m <span class="number">4</span>s</span><br><span class="line">Train Loss: <span class="number">0.115</span> | Train Acc: <span class="number">95.73</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.374</span> | Val. Acc: <span class="number">86.99</span>%</span><br><span class="line">Epoch: <span class="number">09</span> | Epoch Time: <span class="number">7</span>m <span class="number">4</span>s</span><br><span class="line">Train Loss: <span class="number">0.095</span> | Train Acc: <span class="number">96.57</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.259</span> | Val. Acc: <span class="number">91.22</span>%</span><br><span class="line">Epoch: <span class="number">10</span> | Epoch Time: <span class="number">7</span>m <span class="number">5</span>s</span><br><span class="line">Train Loss: <span class="number">0.078</span> | Train Acc: <span class="number">97.30</span>%</span><br><span class="line"> Val. Loss: <span class="number">0.282</span> | Val. Acc: <span class="number">91.77</span>%</span><br></pre></td></tr></table></figure><h1 id="11-模型推断"><a href="#11-模型推断" class="headerlink" title="11. 模型推断"></a>11. 
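<p>例如下面这个小包装(阈值取 0.5,属于示意写法):</p>
<pre><code class="py">def predict_label(model, tokenizer, sentence):
    # 概率大于 0.5 判为 positive,否则判为 negative
    return 'positive' if predict_sentiment(model, tokenizer, sentence) > 0.5 else 'negative'

print(predict_label(model, tokenizer, "This film is great")) # positive
</code></pre>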
]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 17:使用 TorchText 和 transformers 进行情感分类(1)</title>
<link href="2020/06/10/DL-PyTorch-%E6%8A%98%E6%A1%82-17%EF%BC%9A%E4%BD%BF%E7%94%A8-TorchText-%E5%92%8C-transformers-%E8%BF%9B%E8%A1%8C%E6%83%85%E6%84%9F%E5%88%86%E7%B1%BB/"/>
<url>2020/06/10/DL-PyTorch-%E6%8A%98%E6%A1%82-17%EF%BC%9A%E4%BD%BF%E7%94%A8-TorchText-%E5%92%8C-transformers-%E8%BF%9B%E8%A1%8C%E6%83%85%E6%84%9F%E5%88%86%E7%B1%BB/</url>
<content type="html"><![CDATA[<p> 我们已经了解了 PyTorch 的基本操作和功能,现在让我们实践一下。自从 transformer 横空出世以后,在 NLP 领域有”大一统“ 的趋势。但 transformer 的本质是什么?transformer 的本质是一个能够有效提取语义信息的词嵌入生成器,它比前辈 word2vec、GloVe 等等能够更有效地提取词语的语义信息,所以以 transformer 生成的词嵌入可以有 SOTA(state-of-the-art,最高水平)的性能。这等于电脑可以更好地理解文本中每个词语的意思,理解了每个词语的意思自然就可以更好地理解文本的整体意思。所以 transformer 只是取代了以前用的 Embedding 层,根据具体的任务的不同还可以接上 CNN、RNN 等层。</p><p> 本文及下一篇文章中,我们将使用 PyTorch,TorchText 和 transformers 库里的 Bert 预训练模型来进行一个基本的情感分类任务:IMDB 影片评论的情感分类。</p><a id="more"></a> <p> Bert 之类的 transformer 预训练模型虽然性能强大,它也有一个致命的缺点,即资源的消耗非常巨大。简单的模型跑小数据库还可以在个人 PC 上运行,Bert 之类的模型必须用到 GPU。写这篇文章的时候,我是在 Google Colab 上完成模型的训练的。<br> <figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span> !nvidia-smi</span><br><span class="line"> Wed Jun <span class="number">10</span> <span class="number">18</span>:<span class="number">00</span>:<span class="number">42</span> <span class="number">2020</span> </span><br><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| NVIDIA-SMI <span class="number">440.82</span> Driver Version: <span class="number">418.67</span> CUDA Version: <span class="number">10.1</span> |</span><br><span class="line">|-------------------------------+----------------------+----------------------+</span><br><span class="line">| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |</span><br><span class="line">| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |</span><br><span class="line">|===============================+======================+======================|</span><br><span class="line">| <span class="number">0</span> Tesla P100-PCIE... 
Off | <span class="number">00000000</span>:<span class="number">00</span>:<span class="number">04.0</span> Off | <span class="number">0</span> |</span><br><span class="line">| N/A <span class="number">69</span>C P0 <span class="number">48</span>W / <span class="number">250</span>W | <span class="number">9149</span>MiB / <span class="number">16280</span>MiB | <span class="number">0</span>% Default |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br><span class="line"> </span><br><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| Processes: GPU Memory |</span><br><span class="line">| GPU PID Type Process name Usage |</span><br><span class="line">|=============================================================================|</span><br><span class="line">+-----------------------------------------------------------------------------+</span><br></pre></td></tr></table></figure><br> 本文的代码来自 <a href="https://colab.research.google.com/github/bentrevett/pytorch-sentiment-analysis/blob/master/6%20-%20Transformers%20for%20Sentiment%20Analysis.ipynb#scrollTo=xZ1S1o1iH5jT" target="_blank" rel="noopener" title="Transformers for Sentiment Analysis">Transformers for Sentiment Analysis</a>。</p><p> *注:本文的很多资源可能需要科学上网,不能科学上网的话我也没办法哈。</p><h1 id="1-必要库的加载"><a href="#1-必要库的加载" class="headerlink" title="1. 必要库的加载"></a>1. 必要库的加载</h1><p> 首先安装最新版 PyTorch,TorchText 和 transformers。<br> <figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">!pip install -U torchtext</span><br><span class="line">!pip install -U torch</span><br><span class="line">!pip install -U transformers</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> random</span><br><span class="line"></span><br><span class="line">SEED = <span class="number">1234</span></span><br><span class="line"><span class="comment"># 初始化随机种子</span></span><br><span class="line">random.seed(SEED)</span><br><span class="line">np.random.seed(SEED)</span><br><span class="line">torch.manual_seed(SEED)</span><br><span class="line">torch.backends.cudnn.deterministic = <span class="literal">True</span> <span class="comment"># 可以加快训练速度一点点</span></span><br></pre></td></tr></table></figure><br>虽然每一步都会加载必需的库,我还是会把完整路径写出来,方便确认从属。</p><h1 id="2-获取数据:"><a href="#2-获取数据:" class="headerlink" title="2. 获取数据:"></a>2. 获取数据:</h1><p>在 Kaggle 上下载<a href="https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews" target="_blank" rel="noopener" title="IMDB 电影评论数据">数据</a>。</p><h1 id="3-分词器的准备"><a href="#3-分词器的准备" class="headerlink" title="3. 分词器的准备"></a>3. 
<p>我们将使用 Bert 预训练模型的分词器。</p>
<pre><code class="py">from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
</code></pre>
<p>"uncased" 意味着这个分词器不区分大小写,只处理小写字母。不用担心,分词器会自动把文本转换成小写形式再处理。</p>
<pre><code class="py">>>> tokens = tokenizer.tokenize('Hello WORLD how ARE yoU?') # 大小写不敏感的预训练模型会自动转换大小写

>>> print(tokens)
['hello', 'world', 'how', 'are', 'you', '?']

>>> indexes = tokenizer.convert_tokens_to_ids(tokens) # 找到 token 对应的 id

>>> print(indexes)
[7592, 2088, 2129, 2024, 2017, 1029]
</code></pre>
<p>接下来我们需要指定 4 个特殊的 token 备用:句起始 token、句结束 token、填充 token 和未知词语 token。</p>
<pre><code class="py">>>> init_token_idx = tokenizer.cls_token_id # 起始
>>> eos_token_idx = tokenizer.sep_token_id # 结束
>>> pad_token_idx = tokenizer.pad_token_id # 填充
>>> unk_token_idx = tokenizer.unk_token_id # 未知

>>> print(init_token_idx, eos_token_idx, pad_token_idx, unk_token_idx)
101 102 0 100
</code></pre>
<p>然后我们还需要获得预训练 Bert 模型的最大序列长度。</p>
<pre><code class="py">>>> max_input_length = tokenizer.max_model_input_sizes['bert-base-uncased']

>>> print(max_input_length)
512
</code></pre>
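<p>把上面几步串起来,手工给一句话加上 <code>[CLS]</code> 和 <code>[SEP]</code> 对应的 id(利用前面查到的 101 和 102,示意):</p>
<pre><code class="py">ids = [init_token_idx] + tokenizer.convert_tokens_to_ids(tokenizer.tokenize('Hello world')) + [eos_token_idx]
print(ids) # [101, 7592, 2088, 102]
</code></pre>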
class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>max_input_length = tokenizer.max_model_input_sizes[<span class="string">'bert-base-uncased'</span>]</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(max_input_length)</span><br><span class="line"><span class="number">512</span></span><br></pre></td></tr></table></figure><h1 id="4-构建分词器"><a href="#4-构建分词器" class="headerlink" title="4. 构建分词器"></a>4. 构建分词器</h1><p>上一篇文章里说我们可以使用 spacy 作为分词器,这里我们使用 <code>BertTokenizer</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">tokenize_and_cut</span><span class="params">(sentence)</span>:</span></span><br><span class="line"> tokens = tokenizer.tokenize(sentence) </span><br><span class="line"> tokens = tokens[:max_input_length<span class="number">-2</span>]</span><br><span class="line"> <span class="keyword">return</span> tokens</span><br></pre></td></tr></table></figure><p>因为预训练 Bert 模型的最长序列为 512,为给数据点加上 <code>[CLS]</code> 和 <code>[EOS]</code>,我们需要把分词后的序列长度减 2.</p><h1 id="5-定义-field"><a href="#5-定义-field" class="headerlink" title="5. 定义 field"></a>5. 定义 field</h1><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> torchtext <span class="keyword">import</span> data</span><br><span class="line"></span><br><span class="line">TEXT = torchtext.data.Field(batch_first = <span class="literal">True</span>,</span><br><span class="line"> use_vocab = <span class="literal">False</span>,</span><br><span class="line"> tokenize = tokenize_and_cut,</span><br><span class="line"> preprocessing = tokenizer.convert_tokens_to_ids,</span><br><span class="line"> init_token = init_token_idx,</span><br><span class="line"> eos_token = eos_token_idx,</span><br><span class="line"> pad_token = pad_token_idx,</span><br><span class="line"> unk_token = unk_token_idx)</span><br><span class="line"></span><br><span class="line">LABEL = torchtext.data.LabelField(dtype = torch.float)</span><br></pre></td></tr></table></figure><p>因为 batch 在第一维,所以我们设定 <code>batch_first = True</code>。由于我们已经有了单词表(bert 的 embedding),所以需要设置 <code>use_vocab = False</code>。然后我们在 <code>preprocessing</code> 这里将 token 转换成对应的 id。最后,我需要定义特殊的 token id。</p><h1 id="6-加载数据"><a href="#6-加载数据" class="headerlink" title="6. 加载数据"></a>6. 
]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
<tag> NLP </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 16:transformers</title>
<link href="2020/06/05/DL-PyTorch-%E6%8A%98%E6%A1%82-16%EF%BC%9Atransformers/"/>
<url>2020/06/05/DL-PyTorch-%E6%8A%98%E6%A1%82-16%EF%BC%9Atransformers/</url>
<content type="html"><![CDATA[<p> 严格意义上讲 transformers 并不是 PyTorch 的一部分,然而 transformers 与 PyTorch 或 TensorFlow 结合的太紧密了,而且可以把 transformers 看成是 PyTorch 或 TensorFlow 的延伸,所以也在这里一并讨论了。<br> <a id="more"></a><br> transformers 内置了 17 种以 transformer 结构为基础的神经网络:</p><ul><li><p>T5 model</p></li><li><p>DistilBERT model</p></li><li><p>ALBERT model</p></li><li><p>CamemBERT model</p></li><li><p>XLM-RoBERTa model</p></li><li><p>Longformer model</p></li><li><p>RoBERTa model</p></li><li><p>Reformer model</p></li><li><p>Bert model</p></li><li><p>OpenAI GPT model</p></li><li><p>OpenAI GPT-2 model</p></li><li><p>Transformer-XL model</p></li><li><p>XLNet model</p></li><li><p>XLM model</p></li><li><p>CTRL model</p></li><li><p>Flaubert model</p></li><li><p>ELECTRA model</p><p>这些模型的参数、用法大同小异。默认框架为 PyTorch,使用 TensorFlow 框架在类的前面加上 ‘TF” 即可。</p><p>每种模型都有至少一个预训练模型,限于篇幅,这里仅仅列举 Bert 的常用预训练模型:</p><table><thead><tr><th align="center">模型</th><th align="center">模型细节</th></tr></thead><tbody><tr><td align="center"><code>bert-base-uncased</code></td><td align="center">12-layer, 768-hidden, 12-heads, 110M parameters. Trained on lower-cased English text.</td></tr><tr><td align="center"><code>bert-large-uncased</code></td><td align="center">24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on lower-cased English text.</td></tr><tr><td align="center"><code>bert-base-cased</code></td><td align="center">12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased English text.</td></tr><tr><td align="center"><code>bert-large-cased</code></td><td align="center">24-layer, 1024-hidden, 16-heads, 340M parameters. Trained on cased English text.</td></tr><tr><td align="center"><code>bert-base-multilingual-cased</code></td><td align="center">12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased text in the top 104 languages with the largest Wikipedias</td></tr><tr><td align="center"><code>bert-base-chinese</code></td><td align="center">12-layer, 768-hidden, 12-heads, 110M parameters. Trained on cased Chinese Simplified and Traditional text.</td></tr></tbody></table><p>完整的预训练模型列表可以在 <a href="https://huggingface.co/transformers/pretrained_models.html" target="_blank" rel="noopener" title="Pretrained models">transformers 官网</a>上找到。</p><p>使用 transformers 库有三种方法:</p></li></ul><ol><li><p>使用 <code>pipeline</code>;</p></li><li><p>指定预训练模型;</p></li><li><p>使用 <code>AutoModels</code> 加载预训练模型。</p><h1 id="1-transformers-pipeline"><a href="#1-transformers-pipeline" class="headerlink" title="1. transformers.pipeline"></a>1. 
<code>transformers.pipeline</code></h1><p>这个管线函数包含三个部分:</p></li><li><p>Tokenizer;</p></li><li><p>一个模型实例;</p></li><li><p>其它增强模型输出的功能。</p><p>它只有一个必需参数 <code>task</code>,接受如下变量之一:</p></li></ol><ul><li>”feature-extraction”</li><li>”sentiment-analysis”</li><li>”ner”</li><li>”question-answering”</li><li>”fill-mask”</li><li>”summarization”</li><li>”translation_xx_to_yy”</li><li>”text-generation”</li></ul><p>这个函数还有其它可选参数,但是我的试用经验是,什么都不要动,使用默认参数即可。</p><p>例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="keyword">from</span> transformers <span class="keyword">import</span> pipeline</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>nlp = pipeline(<span class="string">"sentiment-analysis"</span>)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(nlp(<span class="string">"I hate you"</span>))</span><br><span class="line">[{<span class="string">'label'</span>: <span class="string">'NEGATIVE'</span>, <span class="string">'score'</span>: <span class="number">0.9991129040718079</span>}]</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(nlp(<span class="string">"I love you"</span>))</span><br><span class="line">[{<span class="string">'label'</span>: <span class="string">'POSITIVE'</span>, <span class="string">'score'</span>: <span class="number">0.9998656511306763</span>}]</span><br></pre></td></tr></table></figure><h1 id="2-指定预训练模型"><a href="#2-指定预训练模型" class="headerlink" title="2. 指定预训练模型"></a>2. 
指定预训练模型</h1><p>这里我们以 Bert 为例。</p><h2 id="2-1-配置-Bert-模型(可选,推荐不使用)transformers-BertConfig"><a href="#2-1-配置-Bert-模型(可选,推荐不使用)transformers-BertConfig" class="headerlink" title="2.1 配置 Bert 模型(可选,推荐不使用)transformers.BertConfig"></a>2.1 配置 Bert 模型(可选,推荐不使用)<code>transformers.BertConfig</code></h2><p><code>transformers.BertConfig</code> 可以自定义 Bert 模型的结构,以下参数都是可选的:</p><ul><li><code>vocab_size</code>:词汇数,默认 30522;</li><li><code>hidden_size</code>:编码器内隐藏层神经元数量,默认 768;</li><li><code>num_hidden_layers</code>:编码器内隐藏层层数,默认 12;</li><li><code>num_attention_heads</code>:编码器内注意力头数,默认 12;</li><li><code>intermediate_size</code>:编码器内全连接层的输入维度,默认 3072;</li><li><code>hidden_act</code>:编码器内激活函数,默认 ‘gelu’,还可为 ‘relu’、’swish’ 或 ‘gelu_new’</li><li><code>hidden_dropout_prob</code>:词嵌入层或编码器的 dropout,默认为 0.1;</li><li><code>attention_probs_dropout_prob</code>:注意力的 dropout,默认为 0.1;</li><li><code>max_position_embeddings</code>:模型使用的最大序列长度,默认为 512;</li><li><code>type_vocab_size</code>:词汇表类别,默认为 2;</li><li><code>initializer_range</code>:神经元权重的标准差,默认为 0.02;</li><li><code>layer_norm_eps</code>:layer normalization 的 epsilon 值,默认为 1e-12.</li></ul><p>使用方法:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">configuration = BertConfig() <span class="comment"># 进行模型的配置,变量为空即使用默认参数</span></span><br><span class="line"></span><br><span class="line">model = BertModel(configuration) <span class="comment"># 使用自定义配置实例化 Bert 模型</span></span><br><span class="line"></span><br><span class="line">configuration = model.config <span class="comment"># 查看模型参数</span></span><br></pre></td></tr></table></figure><h2 id="2-2-分词-transformers-BertTokenizer"><a href="#2-2-分词-transformers-BertTokenizer" class="headerlink" title="2.2 分词 transformers.BertTokenizer"></a>2.2 分词 <code>transformers.BertTokenizer</code></h2><p>所有的 tokenizer 都继承自 <code>transformers.PreTrainedTokenizer</code> 基类,因此有共同的参数和方法实例化的参数有:</p><ul><li><code>model_max_length</code>:可选参数,最大输入长度,默认为 1e30;</li><li><code>padding_side</code>:可选参数,填充的方向,应为 ‘left’ 或 ‘right’;</li><li><code>bos_token</code>:可选参数,每句话的起始标记,默认为 ‘<BOS>‘;</li><li><code>eos_token</code>:可选参数,每句话的结束标记,默认为 ‘<EOS>‘;</li><li><code>unk_token</code>:可选参数,未知的标记,默认为 ‘<UNK>‘;</li><li><code>sep_token</code>:可选参数,分隔标记,默认为 ‘<SEP>‘;</li><li><code>pad_token</code>:可选参数,填充标记,默认为 ‘<PAD>‘;</li><li><code>cls_token</code>:可选参数,分类标记,默认为 ‘<CLS>‘;</li><li><code>mask_token</code>:可选参数,遮盖标记,默认为 ‘<MASK>‘。</li></ul><p>为了演示,我们先实例化一个 <code>BertTokenizer</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">tokenizer = BertTokenizer.from_pretrained(<span class="string">'bert-base-cased'</span>)</span><br></pre></td></tr></table></figure><p>常用的方法有:</p><ul><li><code>from_pretrained(model)</code>:载入预训练词汇表;</li><li><code>tokenizer.tokenize(str)</code>:分词;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tokenizer.tokenize(<span class="string">'Hello word!'</span>)</span><br><span class="line">[<span class="string">'Hello'</span>, <span class="string">'word'</span>, <span class="string">'!'</span>]</span><br></pre></td></tr></table></figure></li><li><code>encode(text, 
...)</code>:将文本分词后编码为包含对应 id 的列表;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tokenizer.encode(<span class="string">'Hello word!'</span>)</span><br><span class="line">[<span class="number">101</span>, <span class="number">8667</span>, <span class="number">1937</span>, <span class="number">106</span>, <span class="number">102</span>]</span><br></pre></td></tr></table></figure></li><li><code>encode_plus(text, ...)</code>:将文本分词后创建一个包含对应 id,token 类型及是否遮盖的词典;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">tokenizer.encode_plus(<span class="string">'Hello world!'</span>)</span><br><span class="line">{<span class="string">'input_ids'</span>: [<span class="number">101</span>, <span class="number">8667</span>, <span class="number">1937</span>, <span class="number">106</span>, <span class="number">102</span>], <span class="string">'token_type_ids'</span>: [<span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>, <span class="number">0</span>], <span class="string">'attention_mask'</span>: [<span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>]}</span><br></pre></td></tr></table></figure></li><li><code>convert_ids_to_tokens(ids, skip_special_tokens)</code>:将 id 映射为 token;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tokenizer.convert_ids_to_tokens(tokens)</span><br><span class="line">[<span class="string">'[CLS]'</span>, <span class="string">'Hello'</span>, <span class="string">'word'</span>, <span class="string">'!'</span>, <span class="string">'[SEP]'</span>]</span><br></pre></td></tr></table></figure></li><li><code>decode(token_ids)</code>:将 id 解码;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tokenizer.decode(tokens)</span><br><span class="line"><span class="string">'[CLS] Hello word! 
[SEP]'</span></span><br></pre></td></tr></table></figure></li><li><code>convert_tokens_to_ids(tokens)</code>:将 token 映射为 id。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tokenizer.convert_tokens_to_ids([<span class="string">'[CLS]'</span>, <span class="string">'Hello'</span>, <span class="string">'word'</span>, <span class="string">'!'</span>, <span class="string">'[SEP]'</span>])</span><br><span class="line">[<span class="number">101</span>, <span class="number">8667</span>, <span class="number">1937</span>, <span class="number">106</span>, <span class="number">102</span>]</span><br></pre></td></tr></table></figure></li></ul><h2 id="2-3-使用预训练模型"><a href="#2-3-使用预训练模型" class="headerlink" title="2.3 使用预训练模型"></a>2.3 使用预训练模型</h2><p>根据任务的需要,既可以选择没有为指定任务 finetune 的模型如 <code>transformers.BertModel</code>,也可以选择为指定任务 finetune 之后的模型如 <code>transformers.BertForSequenceClassification</code>。一共有 6 个指定的任务类型:</p><ul><li><code>transformers.BertForMaskedLM</code>:语言模型;</li><li><code>transformers.BertForNextSentencePrediction</code>:判断下一句话是否与上一句有关;</li><li><code>transformers.BertForSequenceClassification</code>:序列分类如 GLUE;</li><li><code>transformers.BertForMultipleChoice</code>:多项选择;</li><li><code>transformers.BertForTokenClassification</code>:token 分类如 NER;</li><li><code>transformers.BertForQuestionAnswering</code>:问答。</li></ul>
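<p>以 <code>BertForSequenceClassification</code> 为例,下面给出一个最小的使用示意(假设可以下载 bert-base-uncased 的权重;不同版本的 transformers 返回格式略有差异,这里以返回元组、第一个元素为 logits 为例):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">from</span> transformers <span class="keyword">import</span> BertTokenizer, BertForSequenceClassification</span><br><span class="line"></span><br><span class="line">tokenizer = BertTokenizer.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br><span class="line">model = BertForSequenceClassification.from_pretrained(<span class="string">'bert-base-uncased'</span>)  <span class="comment"># 分类头是新初始化的,需要在下游任务上 finetune 后才有意义</span></span><br><span class="line"></span><br><span class="line">inputs = tokenizer.encode_plus(<span class="string">'I love you'</span>, return_tensors=<span class="string">'pt'</span>)  <span class="comment"># 返回 PyTorch 张量</span></span><br><span class="line">outputs = model(**inputs)  <span class="comment"># 元组的第一个元素即为分类的 logits</span></span><br></pre></td></tr></table></figure>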
<h1 id="3-使用-AutoModels"><a href="#3-使用-AutoModels" class="headerlink" title="3. 使用 AutoModels"></a>3. 使用 <code>AutoModels</code></h1><p>使用 <code>AutoModels</code> 与上面指定预训练模型的方式大同小异,只不过是另一种加载模型的方式而已。</p><h2 id="3-1-加载自动配置-transformers-AutoConfig"><a href="#3-1-加载自动配置-transformers-AutoConfig" class="headerlink" title="3.1 加载自动配置 transformers.AutoConfig"></a>3.1 加载自动配置 <code>transformers.AutoConfig</code></h2><p>使用类方法 <code>from_pretrained</code> 加载模型配置,参数既可以为模型名称,也可以为具体文件。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">config = AutoConfig.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br><span class="line"><span class="comment"># 或者直接加载模型文件</span></span><br><span class="line">config = AutoConfig.from_pretrained(<span class="string">'./test/bert_saved_model/'</span>)</span><br></pre></td></tr></table></figure><h2 id="3-2-加载分词器-transformers-AutoTokenizer"><a href="#3-2-加载分词器-transformers-AutoTokenizer" class="headerlink" title="3.2 加载分词器 transformers.AutoTokenizer"></a>3.2 加载分词器 <code>transformers.AutoTokenizer</code></h2><p>与上面的 <code>BertTokenizer</code> 非常相似,也是使用 <code>from_pretrained</code> 类方法加载预训练模型。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">tokenizer = AutoTokenizer.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br><span class="line"><span class="comment"># 或者直接加载模型文件</span></span><br><span class="line">tokenizer = AutoTokenizer.from_pretrained(<span class="string">'./test/bert_saved_model/'</span>)</span><br></pre></td></tr></table></figure><h2 id="3-3-加载模型-transformers-AutoModel"><a href="#3-3-加载模型-transformers-AutoModel" class="headerlink" title="3.3 加载模型 transformers.AutoModel"></a>3.3 加载模型 <code>transformers.AutoModel</code></h2><p>可以使用 <code>from_pretrained</code> 加载预训练模型:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">model = AutoModel.from_pretrained(<span class="string">'bert-base-uncased'</span>)</span><br><span class="line"><span class="comment"># 或者直接加载模型文件</span></span><br><span class="line">model = AutoModel.from_pretrained(<span class="string">'./test/bert_model/'</span>)</span><br></pre></td></tr></table></figure><p>选好了预训练模型以后,只需要给模型接一个全连接层,这个神经网络就搭好了(当然可以根据需要添加更复杂的结构)。是不是香?</p><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt=""></p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 15:TorchText</title>
<link href="2020/06/04/DL-PyTorch-%E6%8A%98%E6%A1%82-15%EF%BC%9ATorchText/"/>
<url>2020/06/04/DL-PyTorch-%E6%8A%98%E6%A1%82-15%EF%BC%9ATorchText/</url>
<content type="html"><![CDATA[<p><code>TorchText</code> 是 PyTorch 的一个功能包,主要提供文本数据读取、创建迭代器的的功能与语料库、词向量的信息,分别对应了 <code>torchtext.data</code>、<code>torchtext.datasets</code> 和 <code>torchtext.vocab</code> 三个子模块。本文参考了三篇文章<a href="https://www.jianshu.com/p/0f7107db2f3a" target="_blank" rel="noopener" title="TorchText学习总结"></a><a href="https://zhuanlan.zhihu.com/p/94941514" target="_blank" rel="noopener" title="使用pytorch和torchtext进行文本分类"></a><a href="https://www.jianshu.com/p/71176275fdc5" target="_blank" rel="noopener" title="Torchtext使用教程"></a>。 </p><a id="more"></a><h1 id="1-语料库-torchtext-datasets"><a href="#1-语料库-torchtext-datasets" class="headerlink" title="1. 语料库 torchtext.datasets"></a>1. 语料库 <code>torchtext.datasets</code></h1><p><code>TorchText</code> 内建的语料库有:</p><ul><li>Language Modeling<ul><li>WikiText-2</li><li>WikiText103</li><li>PennTreebank</li></ul></li><li>Sentiment Analysis<ul><li>SST</li><li>IMDb</li></ul></li><li>Text Classification<ul><li>TextClassificationDataset</li><li>AG_NEWS</li><li>SogouNews</li><li>DBpedia</li><li>YelpReviewPolarity</li><li>YelpReviewFull</li><li>YahooAnswers</li><li>AmazonReviewPolarity</li><li>AmazonReviewFull</li></ul></li><li>Question Classification<ul><li>TREC</li></ul></li><li>Entailment<ul><li>SNLI</li><li>MultiNLI</li></ul></li><li>Machine Translation<ul><li>Multi30k</li><li>IWSLT</li><li>WMT14</li></ul></li><li>Sequence Tagging<ul><li>UDPOS</li><li>CoNLL2000Chunking</li></ul></li><li>Question Answering<ul><li>BABI20</li></ul></li><li>Unsupervised Learning<ul><li>EnWik9</li></ul></li></ul><h1 id="2-预训练的词向量-torchtext-vocab"><a href="#2-预训练的词向量-torchtext-vocab" class="headerlink" title="2. 预训练的词向量 torchtext.vocab"></a>2. 预训练的词向量 <code>torchtext.vocab</code></h1><p><code>TorchText</code> 内建的预训练词向量有:</p><ul><li>charngram.100d </li><li>fasttext.en.300d </li><li>fasttext.simple.300d </li><li>glove.42B.300d </li><li>glove.840B.300d </li><li>glove.twitter.27B.25d </li><li>glove.twitter.27B.50d </li><li>glove.twitter.27B.100d </li><li>glove.twitter.27B.200d </li><li>glove.6B.50d </li><li>glove.6B.100d </li><li>glove.6B.200d </li><li>glove.6B.300d</li></ul><h1 id="3-数据读取、数据框的创建-torchtext-data"><a href="#3-数据读取、数据框的创建-torchtext-data" class="headerlink" title="3. 数据读取、数据框的创建 torchtext.data"></a>3. 
数据读取、数据框的创建 <code>torchtext.data</code></h1><h2 id="3-1-创建-Field"><a href="#3-1-创建-Field" class="headerlink" title="3.1 创建 Field"></a>3.1 创建 <code>Field</code></h2><p><code>Field</code> 可以理解为一个告诉 TorchText 如何处理字段的声明。</p><p><code>torchtext.data.Field(sequential=True, use_vocab=True, init_token=None, eos_token=None, fix_length=None, dtype=torch.int64, preprocessing=None, postprocessing=None, lower=False, tokenize=None, tokenizer_language='en', include_lengths=False, batch_first=False, pad_token='<pad>', unk_token='<unk>', pad_first=False, truncate_first=False, stop_words=None, is_target=False)</code></p><p>参数很多,这里仅仅介绍主要参数:</p><ul><li><code>sequential</code>:是否为序列数据,默认为 True;</li><li><code>use_vocab</code>:是否应用词汇表。若为 False 则数据应该已经是数字形式,默认为 True;</li><li><code>init_token</code>:序列开头填充的 token,默认为 None 即不填充;</li><li><code>eos_token</code>:序列结尾填充的 token,默认为 None 即不填充;</li><li><code>lower</code>:是否将文本转换为小写,默认为 False;</li><li><code>tokenize</code>:分词器,默认为 <code>string.split</code>;</li><li><code>batch_first</code>:batch 是否在第一维上;</li><li><code>pad_token</code>:填充的 token,默认为 “<pad>”;</li><li><code>unk_token</code>:词汇表以外的词汇的表示,默认为 “<unk>”;</li><li><code>pad_first</code>:是否在序列的开头进行填充,默认为 False;</li><li><code>truncate_first</code>:是否在序列的开头将序列超过规定长度的部分进行截断,默认为 False;</li><li><code>stop_words</code>:是否过滤停用词,默认为 False;</li><li><code>is_target</code>:这个 <code>Field</code> 是否为标签,默认为 False。</li></ul><p><code>tokenize</code> 可以使用 SpaCy 的分词功能,使用之前要先构建分词函数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> spacy</span><br><span class="line">spacy_en = spacy.load(<span class="string">'en'</span>)</span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">tokenizer</span><span class="params">(text)</span>:</span></span><br><span class="line">    <span class="keyword">return</span> [tok.text <span class="keyword">for</span> tok <span class="keyword">in</span> spacy_en.tokenizer(text)]</span><br></pre></td></tr></table></figure><p><code>spacy</code> 分词的效果比原生的 <code>split</code> 函数好一点,但是速度也慢一些。然后可以创建对应文本的 Field 了:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">TEXT = data.Field(sequential=<span class="literal">True</span>, tokenize=tokenizer, lower=<span class="literal">True</span>) <span class="comment"># 假设文本为 raw data</span></span><br><span class="line">LABEL = data.Field(sequential=<span class="literal">False</span>, use_vocab=<span class="literal">False</span>) <span class="comment"># 假设标签为离散的数字变量</span></span><br></pre></td></tr></table></figure><h2 id="3-2-创建-Dataset"><a href="#3-2-创建-Dataset" class="headerlink" title="3.2 创建 Dataset"></a>3.2 创建 Dataset</h2><p>如果文本数据保存在 <code>csv</code>、<code>tsv</code> 或 <code>json</code> 文件中,我们优先使用 <code>torchtext.data.TabularDataset</code> 进行读取。</p><p><code>torchtext.data.TabularDataset(path, format, fields, skip_header=False, csv_reader_params={}, **kwargs)</code></p><ul><li><code>path</code>:数据的路径;</li><li><code>format</code>:文件的格式,为 <code>csv</code>、<code>tsv</code> 或 <code>json</code>;</li><li><code>fields</code>:上面已经定义好的 Field;</li><li><code>skip_header</code>:是否跳过第一行;</li><li><code>csv_reader_params</code>:当文件为 <code>csv</code> 或 <code>tsv</code> 时,可以自定义文件的格式。
</li></ul><p>例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">train, val = data.TabularDataset.splits(</span><br><span class="line">    path=<span class="string">'.'</span>, train=<span class="string">'train.csv'</span>, validation=<span class="string">'val.csv'</span>, format=<span class="string">'csv'</span>, skip_header=<span class="literal">True</span>,</span><br><span class="line">    fields=[(<span class="string">'PhraseId'</span>,<span class="literal">None</span>),(<span class="string">'SentenceId'</span>,<span class="literal">None</span>),(<span class="string">'Phrase'</span>, TEXT), (<span class="string">'Sentiment'</span>, LABEL)])</span><br><span class="line"></span><br><span class="line">test = data.TabularDataset(<span class="string">'test.tsv'</span>,</span><br><span class="line">    format=<span class="string">'tsv'</span>, skip_header=<span class="literal">True</span>,</span><br><span class="line">    fields=[(<span class="string">'PhraseId'</span>,<span class="literal">None</span>),(<span class="string">'SentenceId'</span>,<span class="literal">None</span>),(<span class="string">'Phrase'</span>, TEXT)])</span><br></pre></td></tr></table></figure><p>上面的例子说,<code>'PhraseId'</code> 和 <code>'SentenceId'</code> 不读取(<code>Field</code> 为 <code>None</code>),<code>'Phrase'</code> 以 <code>TEXT</code> 的方式进行读取,<code>'Sentiment'</code> 以 <code>LABEL</code> 的方式进行读取。</p><h2 id="3-3-建立词汇表"><a href="#3-3-建立词汇表" class="headerlink" title="3.3 建立词汇表"></a>3.3 建立词汇表</h2><p>现在我们需要将词转化为数字,并在模型中载入预训练好的词向量。词汇表存储在之前声明好的 <code>Field</code> 里面。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line">TEXT.build_vocab(train_data, <span class="comment"># 建词表是用训练集建,不要用验证集和测试集</span></span><br><span class="line">                 max_size=<span class="number">400000</span>, <span class="comment"># 单词表容量</span></span><br><span class="line">                 vectors=<span class="string">'glove.6B.300d'</span>, <span class="comment"># 还有 'glove.840B.300d' 等很多可以选</span></span><br><span class="line">                 unk_init=torch.nn.init.xavier_uniform_ <span class="comment"># 初始化 train_data 中不存在于预训练词表中的单词</span></span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 在神经网络里加载词向量</span></span><br><span class="line">pretrained_embeddings = TEXT.vocab.vectors</span><br><span class="line">model.embedding.weight.data.copy_(pretrained_embeddings)</span><br><span class="line">UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]</span><br><span class="line">PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]</span><br><span class="line"><span class="comment"># 因为预训练权重中 unk 和 pad 的词向量不是在我们的数据集语料上训练得到的,所以最好置零</span></span><br><span class="line">model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)</span><br><span class="line">model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)</span><br></pre></td></tr></table></figure>
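<p>建好词表后,可以顺手检查一下词表和词向量是否对上了(一个小示意,假设 <code>TEXT</code> 即上文建好词表的 Field;<code>stoi</code>、<code>itos</code>、<code>vectors</code> 都是 torchtext 词汇表自带的属性):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">print(len(TEXT.vocab))           <span class="comment"># 词表大小(含 <unk> 和 <pad>)</span></span><br><span class="line">print(TEXT.vocab.vectors.shape)  <span class="comment"># 词向量矩阵:词表大小 x 300</span></span><br><span class="line">print(TEXT.vocab.stoi[<span class="string">'good'</span>])   <span class="comment"># 词 -> id</span></span><br><span class="line">print(TEXT.vocab.itos[:<span class="number">10</span>])      <span class="comment"># id -> 词</span></span><br></pre></td></tr></table></figure>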
class="line">model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)</span><br></pre></td></tr></table></figure><h2 id="3-4-创建迭代器"><a href="#3-4-创建迭代器" class="headerlink" title="3.4 创建迭代器"></a>3.4 创建迭代器</h2><p>迭代器推荐使用 <code>BucketIterator</code>,因为它会将文本中长度相似的序列尽量放在同一个 batch 里,减少 padding,从而减少计算量,加速计算。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torchtext.data.BucketIterator(dataset, batch_size, sort_key=<span class="literal">None</span>, device=<span class="literal">None</span>, batch_size_fn=<span class="literal">None</span>, train=<span class="literal">True</span>, repeat=<span class="literal">False</span>, shuffle=<span class="literal">None</span>, sort=<span class="literal">None</span>, sort_within_batch=<span class="literal">None</span>)</span><br></pre></td></tr></table></figure><ul><li><code>dataset</code>:目标数据;</li><li><code>batch_size</code>:batch 的大小;</li><li><code>sort_key</code>:排序的方式默认为 None;</li><li><code>device</code>:载入的设备,默认为 CPU;</li><li><code>batch_size_fn</code>:取 batch 的函数,默认为 None;</li><li><code>train</code>:是否为训练集,默认为 True;</li><li><code>repeat</code>:在不同的 epoch 中是否重复相同的 iterater,默认为 False;</li><li><code>shuffle</code>:在不同的 epoch 中是否打乱数据的顺序,默认为 None;</li><li><code>sort</code>:是否根据 <code>sort_key</code> 对数据进行排序,默认为 None;</li><li><code>sort_within_batch</code>:是否根据 <code>sort_key</code> 对每个 batch 内的数据进行降序排序。</li></ul><p>举例:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">train_iter, val_iter = data.BucketIterator.split((train, val), batch_size=<span class="number">128</span>, sort_key=<span class="keyword">lambda</span> x: len(x.Phrase), </span><br><span class="line"> shuffle=<span class="literal">True</span>,device=DEVICE)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 在 test_iter , sort一定要设置成 False, 要不然会被 torchtext 搞乱样本顺序</span></span><br><span class="line">test_iter = data.Iterator(dataset=test, batch_size=<span class="number">128</span>, train=<span class="literal">False</span>,</span><br><span class="line"> sort=<span class="literal">False</span>, device=DEVICE)</span><br></pre></td></tr></table></figure><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt=""></p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 14:其它功能</title>
<link href="2020/06/04/DL-PyTorch-%E6%8A%98%E6%A1%82-14%EF%BC%9A%E5%85%B6%E5%AE%83%E5%8A%9F%E8%83%BD/"/>
<url>2020/06/04/DL-PyTorch-%E6%8A%98%E6%A1%82-14%EF%BC%9A%E5%85%B6%E5%AE%83%E5%8A%9F%E8%83%BD/</url>
<content type="html"><![CDATA[<p>本以为 PyTorch 的文章要写两个月,结果发现 PyTorch 真的太轻了,写了不到一个月就写完了。本篇为完结篇,对一些零星的功能进行总结。</p><a id="more"></a><h1 id="1-torch-nn-utils-里的一些功能"><a href="#1-torch-nn-utils-里的一些功能" class="headerlink" title="1. torch.nn.utils 里的一些功能"></a>1. <code>torch.nn.utils</code> 里的一些功能</h1><h2 id="1-1-梯度剪枝"><a href="#1-1-梯度剪枝" class="headerlink" title="1.1 梯度剪枝"></a>1.1 梯度剪枝</h2><p>在 <a href="https://mp.weixin.qq.com/s?__biz=Mzg3OTIwODUzMQ==&mid=2247484983&idx=1&sn=814762bcc217f57ae9507083875963ca&chksm=cf06b70bf8713e1d01b79eaa0cebad4acc66ea622c027f6110d3662fc6844f7435ba62cb93d1&token=885036689&lang=zh_CN#rd" target="_blank" rel="noopener">PyTorch 折桂 8:torch.nn.init</a> 里提到过梯度爆炸的问题,当时我们的解决方法是对神经元权重的初始化进行控制,这里再介绍一个简单粗暴的方式:直接限制权重的上限。对应导数超过上限的权重,将其导数重置为上限值。</p><p><code>torch.nn.utils.clip_grad_value_(parameters, clip_value)</code></p><ul><li><code>parameters</code>:需要修改的权重</li><li><code>clip_value</code>:权重的上限<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>weight = torch.tensor((<span class="number">10.</span>), requires_grad=<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>relu = nn.ReLU() <span class="comment"># 对大于 0 的值,ReLU 的处理结果为 1.</span></span><br><span class="line"><span class="meta">>>> </span>out = relu(weight)</span><br><span class="line"><span class="meta">>>> </span>weight.backward() <span class="comment"># 反向传播</span></span><br><span class="line"><span class="meta">>>> </span>nn.utils.clip_grad_value_(weight, <span class="number">0.5</span>) <span class="comment"># 梯度从 1.0 被限制为 0.5</span></span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(weight.grad)</span><br><span class="line">tensor(<span class="number">0.5000</span>)</span><br></pre></td></tr></table></figure></li></ul><p>可以看到,<code>clip_grad_value_</code> 为 inplace 操作,需要在 <code>Tensor.backward</code> 与 <code>optimizer.step</code> 之间使用。</p><p>除此以外,还有一个 <code>torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=2)</code> 函数,将若干个权重修改为服从正态分布的范围,这里不多赘述。</p><h2 id="1-2-PyTorch-对可变长度序列的处理"><a href="#1-2-PyTorch-对可变长度序列的处理" class="headerlink" title="1.2 PyTorch 对可变长度序列的处理"></a>1.2 PyTorch 对可变长度序列的处理</h2><p>在 NLP 任务中,我们经常要处理不定长度的序列。PyTorch 提供了将不定长度的序列进行打包的函数。</p><ul><li><code>torch.nn.utils.rnn.pad\_sequence(sequences, batch\_first=False, padding\_value=0)</code></li></ul><p>将不定长序列补全至最长序列的长度。接受三个参数:</p><ul><li><code>sequences</code>:接受补全的序列;</li><li><code>batch\_first</code>:批是否为第一个维度,默认为 False;</li><li><code>padding\_value</code>:填充的值,默认为 0。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.ones(<span class="number">25</span>, <span class="number">300</span>)</span><br><span class="line"><span class="meta">>>> </span>b = torch.ones(<span class="number">22</span>, <span class="number">300</span>)</span><br><span class="line"><span class="meta">>>> </span>c = torch.ones(<span class="number">15</span>, <span 
class="number">300</span>)</span><br><span class="line"><span class="meta">>>> </span>torch.nn.utils.rnn.pad_sequence([a, b, c]).size()</span><br><span class="line">torch.Size([<span class="number">25</span>, <span class="number">3</span>, <span class="number">300</span>])</span><br></pre></td></tr></table></figure></li><li><code>torch.nn.utils.rnn.pack\_sequence(sequences, enforce\_sorted=True)</code></li></ul><p>将序列直接打包成一个 <code>PackedSequence</code> 实例。有两个参数:</p><ul><li><code>sequences</code>:要打包的序列;</li><li><code>enforce\_sorted</code>:若为 True,则将序列以长度的降序进行排列,默认为 True。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([<span class="number">1</span>,<span class="number">2</span>,<span class="number">3</span>])</span><br><span class="line"><span class="meta">>>> </span>b = torch.tensor([<span class="number">4</span>,<span class="number">5</span>])</span><br><span class="line"><span class="meta">>>> </span>c = torch.tensor([<span class="number">6</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.nn.utils.rnn.pack_sequence([a, b, c], enforce_sorted=<span class="literal">False</span>)</span><br><span class="line">PackedSequence(data=tensor([<span class="number">1</span>, <span class="number">4</span>, <span class="number">6</span>, <span class="number">2</span>, <span class="number">5</span>, <span class="number">3</span>]), batch_sizes=tensor([<span class="number">3</span>, <span class="number">2</span>, <span class="number">1</span>]), sorted_indices=tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>]), unsorted_indices=tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>]))</span><br></pre></td></tr></table></figure></li></ul><p>除了 padding 与裁剪将所有序列统一为定长以外,PyTorch 还提供了两个函数将不定长度序列打包和解包。</p><ul><li><code>torch.nn.utils.rnn.pack\_padded\_sequence(input, lengths, batch\_first=False, enforce\_sorted=True)</code></li></ul><p>将一个不定长度的序列进行打包,返回一个 <code>PackedSequence</code> 实例。有 4 个参数:</p><ul><li><code>input</code>:一个 <code>T x B x *</code> 尺寸的序列,<code>T</code> 为序列中最长的序列的长度,<code>B</code> 为 batch 的数量,<code>*</code> 为每个序列的维度(可以为 0);</li><li><code>lengths</code>:单个序列长度的列表;</li><li><code>batch\_first</code>:是否以 batch 为第一个维度;</li><li><code>enforce\_sorted</code>:是否对序列以每个序列的长度进行降序排序。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>seq = torch.tensor([[<span class="number">1</span>,<span class="number">2</span>,<span class="number">0</span>], [<span class="number">3</span>,<span class="number">0</span>,<span class="number">0</span>], [<span class="number">4</span>,<span class="number">5</span>,<span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>lens = [<span class="number">2</span>, <span class="number">1</span>, <span class="number">3</span>]</span><br><span class="line"><span class="meta">>>> </span>packed = torch.nn.utils.rnn.pack_padded_sequence(seq, lens, batch_first=<span class="literal">True</span>, 
enforce_sorted=<span class="literal">False</span>)</span><br><span class="line"><span class="meta">>>> </span>packed</span><br><span class="line">PackedSequence(data=tensor([<span class="number">4</span>, <span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">2</span>, <span class="number">6</span>]), batch_sizes=tensor([<span class="number">3</span>, <span class="number">2</span>, <span class="number">1</span>]),</span><br><span class="line"> sorted_indices=tensor([<span class="number">2</span>, <span class="number">0</span>, <span class="number">1</span>]), unsorted_indices=tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">0</span>]))</span><br></pre></td></tr></table></figure></li></ul><p>还有一个与之相反的解包函数:</p><ul><li><code>torch.nn.utils.rnn.pad\_packed\_sequence(sequence, batch\_first=False, padding\_value=0.0, total\_length=None)</code></li></ul><p>这个函数接受一个 <code>PackedSequence</code> 实例,有 4 个参数:</p><ul><li><code>sequence</code>:需要进行解包的序列;</li><li><code>batch\_first</code>:是否以 batch 为第一维;</li><li><code>padding\_value</code>:解包后填充的值,默认为 0;</li><li><code>total\_length</code>:将所有序列填充至 <code>total\_length</code> 的长度。如果这个值小于最长序列的长度,将抛出异常。</li></ul><p>这个函数返回两个张量,解包后的序列和原始序列的长度。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>seq_unpacked, lens_unpacked = torch.nn.utils.rnn.pad_packed_sequence(packed, batch_first=<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>seq_unpacked</span><br><span class="line">tensor([[<span class="number">1</span>, <span class="number">2</span>, <span class="number">0</span>],</span><br><span class="line"> [<span class="number">3</span>, <span class="number">0</span>, <span class="number">0</span>],</span><br><span class="line"> [<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>lens_unpacked</span><br><span class="line">tensor([<span class="number">2</span>, <span class="number">1</span>, <span class="number">3</span>])</span><br></pre></td></tr></table></figure><h1 id="2-GPU-的使用"><a href="#2-GPU-的使用" class="headerlink" title="2. GPU 的使用"></a>2. 
<h1 id="2-GPU-的使用"><a href="#2-GPU-的使用" class="headerlink" title="2. GPU 的使用"></a>2. GPU 的使用</h1><h2 id="2-1-检查系统内-GPU-的状态"><a href="#2-1-检查系统内-GPU-的状态" class="headerlink" title="2.1 检查系统内 GPU 的状态"></a>2.1 检查系统内 GPU 的状态</h2><p>可以使用 <code>nvidia-smi</code> 命令。在 Jupyter Notebook 里要在命令前加上 <code>!</code>。下面为 Google Colab 上的 GPU 状态:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>!nvidia-smi</span><br><span class="line">Thu Jun <span class="number">4</span> <span class="number">16</span>:<span class="number">47</span>:<span class="number">42</span> <span class="number">2020</span> </span><br><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| NVIDIA-SMI <span class="number">440.82</span> Driver Version: <span class="number">418.67</span> CUDA Version: <span class="number">10.1</span> |</span><br><span class="line">|-------------------------------+----------------------+----------------------+</span><br><span class="line">| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |</span><br><span class="line">| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |</span><br><span class="line">|===============================+======================+======================|</span><br><span class="line">| <span class="number">0</span> Tesla K80 Off | <span class="number">00000000</span>:<span class="number">00</span>:<span class="number">04.0</span> Off | <span class="number">0</span> |</span><br><span class="line">| N/A <span class="number">69</span>C P8 <span class="number">33</span>W / <span class="number">149</span>W | <span class="number">11</span>MiB / <span class="number">11441</span>MiB | <span class="number">0</span>% Default |</span><br><span class="line">+-------------------------------+----------------------+----------------------+</span><br><span class="line"> </span><br><span class="line">+-----------------------------------------------------------------------------+</span><br><span class="line">| Processes: GPU Memory |</span><br><span class="line">| GPU PID Type Process name Usage |</span><br><span class="line">|=============================================================================|</span><br><span class="line">| No running processes found |</span><br><span class="line">+-----------------------------------------------------------------------------+</span><br></pre></td></tr></table></figure><h2 id="2-2-在-PyTorch-内检查可用-GPU"><a href="#2-2-在-PyTorch-内检查可用-GPU" class="headerlink" title="2.2 在 PyTorch 内检查可用 GPU"></a>2.2 在 PyTorch 内检查可用 GPU</h2><p>可以使用 <code>torch.cuda.is_available()</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.cuda.is_available()</span><br><span class="line"><span 
class="literal">True</span></span><br></pre></td></tr></table></figure><h2 id="2-3-将神经网络与张量在-CPU-与-GPU-之间移动"><a href="#2-3-将神经网络与张量在-CPU-与-GPU-之间移动" class="headerlink" title="2.3 将神经网络与张量在 CPU 与 GPU 之间移动"></a>2.3 将神经网络与张量在 CPU 与 GPU 之间移动</h2><p>有两种方法:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 第一种方法</span></span><br><span class="line">x = x.cuda() <span class="comment"># 将 x 移动到 GPU 上</span></span><br><span class="line">x = x.cpu() <span class="comment"># 将 x 移动到 CPU 上</span></span><br><span class="line"></span><br><span class="line"><span class="comment"># 第二种方法</span></span><br><span class="line">x = x.to(<span class="string">'cuda'</span>) <span class="comment"># 将 x 移动到 GPU 上</span></span><br><span class="line">x = x.to(<span class="string">'cpu'</span>) <span class="comment"># 将 x 移动到 CPU 上</span></span><br></pre></td></tr></table></figure><p>以上仅为一张 GPU 的情况。GPU 上仅可以进行运算,其它操作需要将张量移动到 CPU 上完成。</p><h1 id="3-模型的保存与读取"><a href="#3-模型的保存与读取" class="headerlink" title="3. 模型的保存与读取"></a>3. 模型的保存与读取</h1><p>保存的模型如果在 GPU 上,需要先转移到 CPU 上。保存模型既可以保存整个模型,也可以只保存模型的参数权重。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 保存整个模型</span></span><br><span class="line">torch.save(the_model, PATH)</span><br><span class="line"><span class="comment"># 只保存模型的参数</span></span><br><span class="line">torch.save(the_model.state_dict(), PATH)</span><br><span class="line"></span><br><span class="line"><span class="comment"># 读取整个模型</span></span><br><span class="line">the_model = torch.load(PATH)</span><br><span class="line"><span class="comment"># 只读取模型的权重</span></span><br><span class="line">the_model.load_state_dict(torch.load(PATH))</span><br></pre></td></tr></table></figure><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt=""></p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 13:RNN</title>
<link href="2020/05/30/DL-PyTorch-%E6%8A%98%E6%A1%82-13%EF%BC%9ARNN/"/>
<url>2020/05/30/DL-PyTorch-%E6%8A%98%E6%A1%82-13%EF%BC%9ARNN/</url>
<content type="html"><![CDATA[<p>RNN(recurrent neural network)擅长处理序列内容,因此在 NLP 中应用较多。然而 RNN 的拓扑结构与 MLP、CNN 完全不同,因此学习起来会有很大的困扰。本文是介绍如何用锤子敲钉子的,而不是如何造锤子或者为什么要敲的。所以 RNN 的原理与使用场景在这里从略。然而了解 RNN 的工作原理对正确使用 RNN 大有裨益,所以在此附上参考资料<a href="https://zhuanlan.zhihu.com/p/32103001" target="_blank" rel="noopener" title="读PyTorch源码学习RNN(1)"></a><a href="https://zhuanlan.zhihu.com/p/80866196" target="_blank" rel="noopener" title="PyTorch 学习笔记(十一):循环神经网络(RNN)"></a><a href="https://zybuluo.com/hanbingtao/note/541458" target="_blank" rel="noopener" title="零基础入门深度学习(5) - 循环神经网络"></a> ,供读者参考。</p><a id="more"></a><p>RNN 主要有三个实现:原始 RNN 和 RNN 的改进版 LSTM 和 GRU。一个循环神经网络主要由输入层、隐藏层(RNN 层)、输出层构成,两层之间由激活函数相连。不像 MLP、CNN 那样多个隐藏层必须显式地写出来,RNN 的隐藏层可以以一个 RNN 的参数表示。所以 RNN 网络的格式是:<br>$$y=\alpha(RNN(x))$$<br>而 RNN、LSTM 和 GRU 的类也是大同小异:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line">torch.nn.RNN(input_size, hidden_size, num_layers=<span class="number">1</span>, nonlinearity=<span class="string">'tanh'</span>, bias=<span class="literal">True</span>, batch_first=<span class="literal">False</span>, dropout=<span class="number">0</span>, bidirectional=<span class="literal">False</span>)</span><br><span class="line">torch.nn.LSTM(input_size, hidden_size, num_layers=<span class="number">1</span>, bias=<span class="literal">True</span>, batch_first=<span class="literal">False</span>, dropout=<span class="number">0</span>, bidirectional=<span class="literal">False</span>)</span><br><span class="line">torch.nn.GRU(input_size, hidden_size, num_layers=<span class="number">1</span>, bias=<span class="literal">True</span>, batch_first=<span class="literal">False</span>, dropout=<span class="number">0</span>, bidirectional=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure><p>可以看到,<code>torch.nn.RNN</code> 比其它两个类就多了一个参数 <code>nonlinearity</code>,这是因为 RNN 里的激活函数可以是 <code>tanh</code> 也可以说 <code>relu</code>,而另外两个类的激活函数已经定义好了。下面逐一说明一下:</p><ul><li><code>input_size</code>:输入 x 中的特征数;</li><li><code>hidden_size</code>:隐藏层的特征数;</li><li><code>num_layers</code>:隐藏层的数量;</li><li><code>bias</code>:是否有偏置项;</li><li><code>batch_first</code>:数据维度中批是否在第一项;</li><li><code>dropout</code>:是否有 dropout;</li><li><code>bidirectional</code>:RNN 是单向还是双向。</li></ul><p>RNN 实例接受的参数有两个:一个张量和上一次的隐藏层:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>rnn = RNN(input_size, hidden_size)</span><br><span class="line"><span class="meta">>>> </span>output, hidden_current = rnn(input, hidden_previous)</span><br></pre></td></tr></table></figure><p>RNN 的输出有两个,分别是输出值和当前的隐藏层。在 <code>batch_first=True</code> 的时候,当前的隐藏层的维度为 <code>(batch, seq_len, num_directions*hidden_size)</code>,而前一个隐藏层的维度为 <code>batch, num_layers*num_directions, hidden_size</code>。我们来看一个例子:我们首先创建一个接受维度为 <code>(1, 5, 2)</code>(每批一个数据点,每个数据点有 5 个特征,两个隐藏层)的 RNN 层,其它参数使用默认参数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>rnn = torch.nn.LSTM(<span class="number">1</span>, <span class="number">5</span>, <span class="number">2</span>, batch_first=<span class="literal">True</span>)</span><br></pre></td></tr></table></figure><p>然后创建一个两批、每批 3 
每个时间步 1 个特征的张量:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.rand(<span class="number">2</span>, <span class="number">3</span>, <span class="number">1</span>)</span><br><span class="line"><span class="meta">>>> </span>print(a)</span><br><span class="line">tensor([[[<span class="number">0.9472</span>],</span><br><span class="line">         [<span class="number">0.1003</span>],</span><br><span class="line">         [<span class="number">0.7684</span>]],</span><br><span class="line"></span><br><span class="line">        [[<span class="number">0.8318</span>],</span><br><span class="line">         [<span class="number">0.7707</span>],</span><br><span class="line">         [<span class="number">0.2214</span>]]])</span><br></pre></td></tr></table></figure><p>将这个张量喂给 RNN:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>out, h = rnn(a)</span><br></pre></td></tr></table></figure><p>这里我们没有给 RNN 网络传入初始隐藏状态,所以 RNN 自动创建了一个全为 0 的初始隐藏状态。我们看一下输出:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>print(out.size())</span><br><span class="line">torch.Size([<span class="number">2</span>, <span class="number">3</span>, <span class="number">5</span>])</span><br><span class="line"><span class="meta">>>> </span>print(out)</span><br><span class="line">tensor([[[ <span class="number">0.0620</span>, <span class="number">0.0790</span>, <span class="number">-0.0028</span>, <span class="number">-0.1094</span>, <span class="number">0.1258</span>],</span><br><span class="line">         [ <span class="number">0.0840</span>, <span class="number">0.0963</span>, <span class="number">-0.0315</span>, <span class="number">-0.1287</span>, <span class="number">0.1837</span>],</span><br><span class="line">         [ <span class="number">0.0983</span>, <span class="number">0.1190</span>, <span class="number">-0.0491</span>, <span class="number">-0.1257</span>, <span class="number">0.2184</span>]],</span><br><span class="line"></span><br><span class="line">        [[ <span class="number">0.0612</span>, <span class="number">0.0764</span>, <span class="number">-0.0047</span>, <span class="number">-0.1101</span>, <span 
class="number">0.1244</span>],</span><br><span class="line"> [ <span class="number">0.0865</span>, <span class="number">0.1130</span>, <span class="number">-0.0228</span>, <span class="number">-0.1283</span>, <span class="number">0.1899</span>],</span><br><span class="line"> [ <span class="number">0.0992</span>, <span class="number">0.1151</span>, <span class="number">-0.0485</span>, <span class="number">-0.1235</span>, <span class="number">0.2183</span>]]],</span><br><span class="line"> grad_fn=<TransposeBackward0>)</span><br><span class="line"> </span><br><span class="line"><span class="meta">>>> </span>print(h[<span class="number">0</span>].size())</span><br><span class="line">torch.Size([<span class="number">2</span>, <span class="number">2</span>, <span class="number">5</span>])</span><br><span class="line"><span class="meta">>>> </span>print(h)</span><br><span class="line">(tensor([[[<span class="number">-0.0562</span>, <span class="number">-0.0368</span>, <span class="number">-0.1863</span>, <span class="number">-0.2322</span>, <span class="number">0.0921</span>],</span><br><span class="line"> [<span class="number">-0.0424</span>, <span class="number">-0.0347</span>, <span class="number">-0.1600</span>, <span class="number">-0.1809</span>, <span class="number">0.1258</span>]],</span><br><span class="line"> </span><br><span class="line"> [[ <span class="number">0.0983</span>, <span class="number">0.1190</span>, <span class="number">-0.0491</span>, <span class="number">-0.1257</span>, <span class="number">0.2184</span>],</span><br><span class="line"> [ <span class="number">0.0992</span>, <span class="number">0.1151</span>, <span class="number">-0.0485</span>, <span class="number">-0.1235</span>, <span class="number">0.2183</span>]]],</span><br><span class="line"> grad_fn=<StackBackward>),</span><br><span class="line"> tensor([[[<span class="number">-0.1437</span>, <span class="number">-0.0643</span>, <span class="number">-0.3578</span>, <span class="number">-0.3889</span>, <span class="number">0.1648</span>],</span><br><span class="line"> [<span class="number">-0.1044</span>, <span class="number">-0.0650</span>, <span class="number">-0.3243</span>, <span class="number">-0.3031</span>, <span class="number">0.2357</span>]],</span><br><span class="line"> </span><br><span class="line"> [[ <span class="number">0.1939</span>, <span class="number">0.1787</span>, <span class="number">-0.0983</span>, <span class="number">-0.2349</span>, <span class="number">0.3685</span>],</span><br><span class="line"> [ <span class="number">0.1932</span>, <span class="number">0.1733</span>, <span class="number">-0.0973</span>, <span class="number">-0.2295</span>, <span class="number">0.3687</span>]]],</span><br><span class="line"> grad_fn=<StackBackward>))</span><br></pre></td></tr></table></figure><p>为什么会是这样呢?模型和输入张量的维度分别为:</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">LSTM(input_size, hidden_size, num_layer)</span><br><span class="line">1 5 2</span><br><span class="line">trnsor(batch (if 'batch_first=True'), seq_len, input_size)</span><br><span class="line">2 3 1</span><br></pre></td></tr></table></figure><p>输出张量的维度为:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span 
class="line">4</span><br></pre></td><td class="code"><pre><span class="line">out.shape: <span class="number">2</span>, <span class="number">3</span>, <span class="number">5</span></span><br><span class="line"> batch, seq_len, num_directions*hidden_size</span><br><span class="line">hidden.shape: <span class="number">2</span>, <span class="number">2</span>, <span class="number">5</span></span><br><span class="line"> num_layers*num_directions, batch, hidden_zie</span><br></pre></td></tr></table></figure><p>是不是一目了然?这里要注意,RNN 在内部运算的时候,张量的维度是 <code>(inpu_size, batch, hidden_size)</code>,虽然我们设置 <code>batch_first=True</code> 将输入和输出的张量的 batch 放到了第一维,输入和输出的 hidden 的 batch 仍然在第二维。</p><p>RNN 的改进版 LSTM 和 GRU 的原理可以看这里 <a href="https://zhuanlan.zhihu.com/p/79064602" target="_blank" rel="noopener" title="LSTM细节分析理解(pytorch版)">1</a> <a href="https://www.zhihu.com/question/41949741/answer/318771336" target="_blank" rel="noopener" title="LSTM神经网络输入输出究竟是怎样的?">2</a> <a href="https://zhuanlan.zhihu.com/p/46981722" target="_blank" rel="noopener" title="超生动图解LSTM和GRU:拯救循环神经网络的记忆障碍,就靠它们了">3</a>。</p><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt=""></p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 12:CNN</title>
<link href="2020/05/27/DL-PyTorch-%E6%8A%98%E6%A1%82-12%EF%BC%9ACNN/"/>
<url>2020/05/27/DL-PyTorch-%E6%8A%98%E6%A1%82-12%EF%BC%9ACNN/</url>
<content type="html"><![CDATA[<p> 本文尽量不涉及 CNN(卷积神经网络)的原理,仅讨论 CNN 的 PyTorch 实现。CNN 独有的层包括卷积层(convolution layer),池化层(pooling layer),转置卷积层(transposed convolution layer),反池化层(unpooling layer)。卷积层与池化层在 CNN 中最常用,而转置卷积层与反池化层通常用于计算机视觉应用里的图像再生,对于 NLP 来说应用不多,不再赘述。</p><a id="more"></a> <h1 id="1-卷积神经网络工作原理"><a href="#1-卷积神经网络工作原理" class="headerlink" title="1. 卷积神经网络工作原理"></a>1. 卷积神经网络工作原理</h1><p> 从工程实现的角度来说,一个 CNN 网络可以分成两部分:特征学习阶段与分类阶段。<br> <img src="https://img-blog.csdnimg.cn/20200515075546921.jpeg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="A Comprehensive Guide to Convolutional Neural Networks — the ELI5 way"><br>特征学习层由多层卷积层与池化层叠加,之间使用 relu 作为激活函数。卷积层的作用是使信息变深(层数增加),通常会使层的长宽减小;池化层的作用是使信息变窄,提取主要信息。之后进入分类层,将信息变成一维向量,经过 1-3 层全连接层与 relu 之后,经过最终的 softmax 层进行分类;若目标为二分类,则也可以经过 sigmoid 层。</p><h1 id="2-convolution-layer-卷积层"><a href="#2-convolution-layer-卷积层" class="headerlink" title="2. convolution layer 卷积层"></a>2. convolution layer 卷积层</h1><p> 卷积层有三个类,分别是:</p><ul><li><code>torch.nn.Conv1d</code></li><li><code>torch.nn.Conv2d</code></li><li><code>torch.nn.Conv3d</code><br>这三个类分别对应了文本(一维数据)、图片(二维数据)和视频(三维数据)。它们的维度如下:<ul><li>一维数据是一个 3 维张量:batch * channel * feature;</li><li>二维数据是一个 4 维张量:batch * channel * weight * height;</li><li>三维数据是一个 5 维张量:batch * channel * frame * weight * height。</li></ul></li></ul><p>可见,三个类处理的数据的前两维是完全一致的。此外,三个类的参数也完全一致,以 <code>torch.nn.Conv2d</code> 为例:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=<span class="number">1</span>, padding=<span class="number">0</span>, dilation=<span class="number">1</span>, groups=<span class="number">1</span>, bias=<span class="literal">True</span>, padding_mode=<span class="string">'zeros'</span>)</span><br></pre></td></tr></table></figure><ul><li><code>in_channels</code>:输入张量的层数;</li><li><code>out_channels</code>:输出张量的层数;</li><li><code>kernel_size</code>:卷积核的大小,整数或元组;</li><li><code>stride</code>:卷积的步长,整数或元组;</li><li><code>padding</code>:填充的宽度,整数或元组;</li><li><code>dilation</code>:稀释的跨度,整数或元组;</li><li><code>groups</code>:卷积的分组;</li><li><code>bias</code>:偏置项;</li><li><code>padding_mode</code>:填充的方法。</li></ul><p>当所有尺寸均为矩形的时候,输出张量的长和宽的数值为:<br>$$dimension=\frac{H_{in}+2\times padding-dilution\times (kernel_size-1)-1}{stride}$$</p><ul><li><strong>一个 trick</strong>:当 $kernel_size=3$,$padding=1$,$stride=1$ 的时候,输入张量和输出张量的长宽是不变的。</li></ul><p>池化层的权重是随机初始化的,不过我们也可以手动设定。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>conv = torch.nn.Conv2d(<span class="number">1</span>, <span class="number">1</span>, <span class="number">3</span>, bias=<span class="number">0.</span>) <span class="comment"># 定义一个 3x3 的卷积核</span></span><br><span class="line"><span class="meta">>>> </span>nn.init.constant_(conv.weight.data, <span 
class="number">1.</span>) <span class="comment"># 卷积核的权重设为 1.</span></span><br><span class="line"><span class="meta">>>> </span>print(Convolutional.weight.data)</span><br><span class="line">tensor([[[[<span class="number">1.</span>, <span class="number">1.</span>, <span class="number">1.</span>],</span><br><span class="line"> [<span class="number">1.</span>, <span class="number">1.</span>, <span class="number">1.</span>],</span><br><span class="line"> [<span class="number">1.</span>, <span class="number">1.</span>, <span class="number">1.</span>]]]])</span><br><span class="line"><span class="meta">>>> </span>tensor = torch.linspace(<span class="number">16.</span>, <span class="number">1.</span>, <span class="number">16</span>).reshape(<span class="number">1</span>, <span class="number">1</span>, <span class="number">4</span>, <span class="number">4</span>) <span class="comment"># 定义一个张量</span></span><br><span class="line"><span class="meta">>>> </span>print(tensor)</span><br><span class="line">tensor([[[[<span class="number">16.</span>, <span class="number">15.</span>, <span class="number">14.</span>, <span class="number">13.</span>],</span><br><span class="line"> [<span class="number">12.</span>, <span class="number">11.</span>, <span class="number">10.</span>, <span class="number">9.</span>],</span><br><span class="line"> [ <span class="number">8.</span>, <span class="number">7.</span>, <span class="number">6.</span>, <span class="number">5.</span>],</span><br><span class="line"> [ <span class="number">4.</span>, <span class="number">3.</span>, <span class="number">2.</span>, <span class="number">1.</span>]]]])</span><br><span class="line"><span class="meta">>>> </span>conv(tensor) <span class="comment"># 卷积操作</span></span><br><span class="line">tensor([[[[<span class="number">99.</span>, <span class="number">90.</span>],</span><br><span class="line"> [<span class="number">63.</span>, <span class="number">54.</span>]]]], grad_fn=<MkldnnConvolutionBackward>)</span><br></pre></td></tr></table></figure><p>上例中,卷积核是一个 $3\times3$ 的全 1 张量;在卷积运算中,卷积核先与张量中前三排中的前三个元素进行 elementwise 的乘法,然后相加,得到输出张量中的第一个元素。然后向右滑动一个元素(因为 <code>stride</code> 默认是 1),重复卷积运算;既然达到末尾,返回左侧向下滑动一个单位,继续运算,直到到达末尾。</p><h1 id="3-pool-layer-池化层"><a href="#3-pool-layer-池化层" class="headerlink" title="3. pool layer 池化层"></a>3. 
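<p>下面用一个小例子验证上面的输出尺寸公式(卷积参数是随意选的,仅作演示):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">conv = torch.nn.Conv2d(1, 1, kernel_size=3, stride=2, padding=1, dilation=2)</span><br><span class="line">x = torch.randn(1, 1, 28, 28)</span><br><span class="line"># 按上面的公式:floor((28 + 2*1 - 2*(3-1) - 1) / 2 + 1) = 13</span><br><span class="line">print(conv(x).shape)  # torch.Size([1, 1, 13, 13])</span><br></pre></td></tr></table></figure><h1 id="3-pool-layer-池化层"><a href="#3-pool-layer-池化层" class="headerlink" title="3. pool layer 池化层"></a>3. 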
pool layer 池化层</h1><p>与卷积层对应的,池化层分为最大池化和平均池化两种,每种也有三个类:</p><ul><li><code>torch.nn.MaxPool1d</code></li><li><code>torch.nn.MaxPool2d</code></li><li><code>torch.nn.MaxPool3d</code></li><li><code>torch.nn.AvgPool1d</code></li><li><code>torch.nn.AvgPool2d</code></li><li><code>torch.nn.AvgPool3d</code></li></ul><p>所谓“池化”,就是按照一定的规则(选取最大值或计算平均值)在输入层的窗口里计算数据,返回计算结果。它们的参数也一致,最大池化层只有三个参数:</p><ul><li><code>kernel_size</code>:池化窗口的大小,整数或元组;</li><li><code>stride</code>:池化的步长,整数或元组;</li><li><code>padding</code>:填充的宽度,整数或元组;</li></ul><p>一维平均池化层有额外的两个参数:</p><ul><li><code>ceil_mode</code>:对结果进行上取整;</li><li><code>count_include_pad</code>:是否将 padding 纳入计算;</li></ul><p>二维及三维平均池化层有额外的一个参数:</p><ul><li><p><code>divisor_override</code>:指定一个除数。</p></li><li><p><strong>一个 trick</strong>:当 $\text{kernel\_size}=2$,$\text{stride}=2$ 的时候,输出张量的长宽是输入张量的一半。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>pool = torch.nn.MaxPool2d(<span class="number">2</span>) <span class="comment"># 定义一个大小为 2x2 的核</span></span><br><span class="line"><span class="meta">>>> </span>pool(tensor) <span class="comment"># 池化操作</span></span><br><span class="line">tensor([[[[<span class="number">16.</span>, <span class="number">14.</span>],</span><br><span class="line"> [ <span class="number">8.</span>, <span class="number">6.</span>]]]])</span><br></pre></td></tr></table></figure></li></ul><h1 id="4-CNN-实战"><a href="#4-CNN-实战" class="headerlink" title="4. CNN 实战"></a>4. CNN 实战</h1><p>我们还是使用<a href="https://vincent507cpu.github.io/2020/05/27/DL-PyTorch-折桂-11:使用全连接网络进行手写数字识别/">《[DL] PyTorch 折桂 11:使用全连接网络进行手写数字识别》</a> 里的任务,只不过这一次我们使用 CNN 搭建神经网络。除了第 2、5 步,其它代码都是一样的,所以这里只有这两步的代码,其它代码请看前文。</p><p>构建神经网络时唯一要注意的是最后的全连接层的入度。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">CNN</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"> super(CNN, self).__init__()</span><br><span class="line"> self.conv = nn.Conv2d(<span class="number">1</span>, <span class="number">4</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>) <span class="comment"># 维度不变</span></span><br><span class="line"> self.pool = nn.MaxPool2d(<span class="number">2</span>, <span class="number">2</span>) <span class="comment"># 维度减半</span></span><br><span class="line"> self.fc = nn.Linear(<span class="number">28</span>*<span class="number">28</span>, <span class="number">10</span>)</span><br><span class="line"> self.softmax = nn.LogSoftmax(dim=<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> x = F.relu(self.conv(x))</span><br><span class="line"> x = F.relu(self.pool(x))</span><br><span class="line"> x = self.fc(x.view(x.shape[<span class="number">0</span>], <span class="number">-1</span>))</span><br><span class="line"> out = self.softmax(x)</span><br><span class="line"> <span class="keyword">return</span> out</span><br></pre></td></tr></table></figure><p>这个模型里,每一个 batch 经过卷积层以前的维度是 <code>[batch, 1, 28, 28]</code>,经过卷积层后长宽不变而通道数变成了 4;通过池化层以后每个 batch 的维度变成了 <code>[batch, 4, 14, 14]</code>,打平后的长度为 $4\times14\times14=784=28\times28$,所以全连接层的入度不变。还有一点要注意:卷积层接受的是带有空间维度的四维张量,所以打平这个操作要放在模型里面的全连接层之前,而不是像上一篇那样放在训练循环中,如下面的小例子所示。</p>
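<p>一个小演示(假设 <code>CNN</code> 类与前文的 import,包括 <code>torch.nn.functional</code>(记作 <code>F</code>),均已就绪),用随机张量验证上述维度:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">cnn = CNN()</span><br><span class="line">x = torch.randn(64, 1, 28, 28)  # 模拟一个 batch 的输入</span><br><span class="line">print(cnn.pool(F.relu(cnn.conv(x))).shape)  # torch.Size([64, 4, 14, 14])</span><br><span class="line">print(cnn(x).shape)  # torch.Size([64, 10])</span><br></pre></td></tr></table></figure><p>训练 15 个 epoch:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span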
class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>correct_count, all_count = <span class="number">0</span>, <span class="number">0</span></span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span><span class="keyword">with</span> torch.no_grad():</span><br><span class="line"><span class="meta">... </span> <span class="keyword">for</span> data <span class="keyword">in</span> valloader:</span><br><span class="line"><span class="meta">... </span> images, labels = data</span><br><span class="line"><span class="meta">... </span> outputs = cnn(images)</span><br><span class="line"><span class="meta">... </span> _, predicted = torch.max(outputs.data, <span class="number">1</span>)</span><br><span class="line"><span class="meta">... </span> all_count += labels.size(<span class="number">0</span>)</span><br><span class="line"><span class="meta">... </span> correct_count += (predicted == labels).sum().item()</span><br><span class="line"></span><br><span class="line"><span class="meta">... </span>print(<span class="string">"Number Of Images Tested ="</span>, all_count)</span><br><span class="line"><span class="meta">... </span>print(<span class="string">"\nModel Accuracy ="</span>, (correct_count/all_count))</span><br><span class="line">Number Of Images Tested = <span class="number">10000</span></span><br><span class="line"></span><br><span class="line">Model Accuracy = <span class="number">0.9705</span></span><br></pre></td></tr></table></figure><p>准确率果然超过了 97%。CNN YES!</p><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt=""></p>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 11:使用全连接网络进行手写数字识别</title>
<link href="2020/05/27/DL-PyTorch-%E6%8A%98%E6%A1%82-11%EF%BC%9A%E4%BD%BF%E7%94%A8%E5%85%A8%E8%BF%9E%E6%8E%A5%E7%BD%91%E7%BB%9C%E8%BF%9B%E8%A1%8C%E6%89%8B%E5%86%99%E6%95%B0%E5%AD%97%E8%AF%86%E5%88%AB/"/>
<url>2020/05/27/DL-PyTorch-%E6%8A%98%E6%A1%82-11%EF%BC%9A%E4%BD%BF%E7%94%A8%E5%85%A8%E8%BF%9E%E6%8E%A5%E7%BD%91%E7%BB%9C%E8%BF%9B%E8%A1%8C%E6%89%8B%E5%86%99%E6%95%B0%E5%AD%97%E8%AF%86%E5%88%AB/</url>
<content type="html"><![CDATA[<p>光说不练假把式,现在我们已经积累了那么多的 PyTorch 知识,让我们实践一下吧!</p><p>本文从简单的手写数字识别入手,参考了若干文章:</p><ul><li><a href="https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627" target="_blank" rel="noopener" title="Handwritten Digit Recognition Using PyTorch — Intro To Neural Networks">Handwritten Digit Recognition Using PyTorch — Intro To Neural Networks</a></li><li><a href="https://www.pluralsight.com/guides/building-your-first-pytorch-solution" target="_blank" rel="noopener" title="Building Your First PyTorch Solution">Building Your First PyTorch Solution</a></li></ul><h1 id="1-PyTorch-使用总览"><a href="#1-PyTorch-使用总览" class="headerlink" title="1. PyTorch 使用总览"></a>1. PyTorch 使用总览</h1><p>使用 PyTorch 进行深度学习的步骤主要分以下七步:</p><ol><li>准备数据,包括数据的预处理和封装;</li><li>模型搭建;</li><li>选择损失函数;</li><li>选择优化器;</li><li>迭代训练;</li><li>评估模型;</li><li>保存模型。</li></ol><a id="more"></a><p>其实第 3 - 5 步反而是最简单的,复杂的地方主要集中在第 1、6 步上。在本文中,我们将使用 PyTorch 搭建一个全连接神经网络,用来识别 MNIST 数据集中的手写数字。本文不涉及 GPU 的使用。</p><h1 id="2-PyTorch实践"><a href="#2-PyTorch实践" class="headerlink" title="2. PyTorch实践"></a>2. PyTorch实践</h1><h2 id="2-0-载入必须的库"><a href="#2-0-载入必须的库" class="headerlink" title="2.0 载入必须的库"></a>2.0 载入必须的库</h2><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> numpy <span class="keyword">as</span> np</span><br><span class="line"><span class="keyword">import</span> matplotlib.pyplot <span class="keyword">as</span> plt</span><br><span class="line"></span><br><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"><span class="keyword">import</span> torchvision</span><br><span class="line"><span class="keyword">from</span> torchvision <span class="keyword">import</span> datasets, transforms</span><br><span class="line"><span class="keyword">from</span> torch <span class="keyword">import</span> nn, optim</span><br></pre></td></tr></table></figure><h2 id="2-1-准备数据"><a href="#2-1-准备数据" class="headerlink" title="2.1 准备数据"></a>2.1 准备数据</h2><p>首先简单介绍一下我们使用的数据库。MNIST 数据库由美国国家标准与技术研究院(NIST)推出,包含了 70000 张手写的 0 - 9 的图片,每个数字 7000 张。训练集 60000 张,测试集 10000 张。每张图片都经过预处理,转换成了 28×28 尺寸的单通道灰度图像。</p><h3 id="2-1-1-获取数据"><a href="#2-1-1-获取数据" class="headerlink" title="2.1.1 获取数据"></a>2.1.1 获取数据</h3><p>我们从 <code>torchvision.datasets</code> 获取数据:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">transform = transforms.Compose([transforms.ToTensor(),</span><br><span class="line"> transforms.Normalize((<span class="number">0.5</span>,), (<span class="number">0.5</span>,)),</span><br><span class="line"> ])</span><br><span class="line"></span><br><span class="line">trainset = datasets.MNIST(<span class="string">'./data'</span>, download=<span class="literal">True</span>, train=<span class="literal">True</span>, transform=transform)</span><br><span class="line">testset = datasets.MNIST(<span class="string">'./data'</span>, download=<span class="literal">True</span>, train=<span class="literal">False</span>, 
transform=transform)</span><br></pre></td></tr></table></figure><p><code>transforms</code> 库在下载图像数据时会对数据进行处理。<code>transforms.ToTensor()</code> 将一个维度为 <code>(H x W x C)</code> 的 RGB 文件转换为一个维度为 <code>(C x H x W)</code> 的张量,数值范围从 <code>[0, 255]</code> 转换为 <code>[0, 1]</code>。<code>transforms.Normalize()</code> 按给定的均值和标准差对数据做标准化(减去均值、除以标准差),注意它并不保证结果服从正态分布。<code>transforms.Compose()</code> 将所有转换打包。</p><p>设置好下载时的预处理方式,我们就可以下载数据了。第一个参数 <code>'./data'</code> 指定了保存的地址,第三个参数 <code>train</code> 的值 <code>True</code> 和 <code>False</code> 分别对应了训练集和测试集。</p><h3 id="2-1-2-封装数据"><a href="#2-1-2-封装数据" class="headerlink" title="2.1.2 封装数据"></a>2.1.2 封装数据</h3><p>我们不能把 60000 个图片一次全部给神经网络,需要按照 batch 的尺寸分批喂给模型,有时在此之前还要随机打乱。关于封装数据,请见前文<a href="https://vincent507cpu.github.io/2020/05/14/DL-PyTorch-折桂-5:PyTorch-模块总览-torch-utils-data/">《[DL] PyTorch 折桂 5:PyTorch 模块总览 & torch.utils.data》</a>。这一次我们设置 batch size 为 64。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">trainloader = torch.utils.data.DataLoader(trainset, batch_size=<span class="number">64</span>, shuffle=<span class="literal">True</span>)</span><br><span class="line">valloader = torch.utils.data.DataLoader(testset, batch_size=<span class="number">64</span>, shuffle=<span class="literal">True</span>)</span><br></pre></td></tr></table></figure><h3 id="2-1-3-exploratory-data-analysis-EDA"><a href="#2-1-3-exploratory-data-analysis-EDA" class="headerlink" title="2.1.3 exploratory data analysis (EDA)"></a>2.1.3 exploratory data analysis (EDA)</h3><p>拿到数据以后很重要的步骤是对数据进行基本的了解,包括数量、维度等等。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>len(trainset)</span><br><span class="line"><span class="number">60000</span></span><br><span class="line"><span class="meta">>>> </span>len(testset)</span><br><span class="line"><span class="number">10000</span></span><br></pre></td></tr></table></figure><p>接下来对图片进行可视化:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">show_batch</span><span class="params">(batch)</span>:</span></span><br><span class="line"> im = torchvision.utils.make_grid(batch)</span><br><span class="line"> plt.imshow(np.transpose(im.numpy(), (<span class="number">1</span>, <span class="number">2</span>, <span class="number">0</span>)))</span><br><span class="line"></span><br><span class="line">dataiter = iter(trainloader)</span><br><span class="line">images, labels = next(dataiter)</span><br><span class="line"></span><br><span class="line">show_batch(images)</span><br></pre></td></tr></table></figure><p><img src="https://img-blog.csdnimg.cn/20200527195745560.png" alt=""><br>查看图片的尺寸:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>images[<span class="number">0</span>].shape</span><br><span
class="line">torch.Size([<span class="number">1</span>, <span class="number">28</span>, <span class="number">28</span>])</span><br></pre></td></tr></table></figure><h2 id="2-2-搭建模型"><a href="#2-2-搭建模型" class="headerlink" title="2.2 搭建模型"></a>2.2 搭建模型</h2><p>搭建模型有两种方法:简单但稍欠灵活性的 <code>nn.Sequential</code>,和更灵活的模块化搭建方法。因为后续还会有实战,这次我们仅仅搭建一个最简单的一层全连接网络。关于搭建模型使用的 <code>nn.Module</code> 的详情请看 <a href="https://vincent507cpu.github.io/2020/05/14/DL-PyTorch-折桂-6:torch-nn-Module/">《[DL] PyTorch 折桂 6:torch.nn.Module》</a>。</p><p>首先来看如何使用 <code>nn.Sequential</code>:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">model = nn.Sequential(nn.Linear(<span class="number">28</span>*<span class="number">28</span>, <span class="number">10</span>),</span><br><span class="line"> nn.LogSoftmax(dim=<span class="number">1</span>))</span><br></pre></td></tr></table></figure><p>我们也可以使用模块化方式搭建模型,与 <code>nn.Sequential</code> 方法搭建的模型是等价的:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">Net</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"> super(Net, self).__init__()</span><br><span class="line"> self.fc = nn.Linear(<span class="number">28</span>*<span class="number">28</span>, <span class="number">10</span>)</span><br><span class="line"> self.softmax = nn.LogSoftmax(dim=<span class="number">1</span>)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> x = self.fc(x)</span><br><span class="line"> x = self.softmax(x)</span><br><span class="line"> <span class="keyword">return</span> x</span><br><span class="line"></span><br><span class="line">model = Net() <span class="comment"># 模块化构建神经网络需要先实例化</span></span><br></pre></td></tr></table></figure><p>因为全连接层使用矩阵乘法进行运算,输入应该是一个一维向量,而且输入的最后一维的维度要与全连接层的入度相同。所以我们需要先把一张图片打平(后面训练的时候做),然后将打平后的长度作为全连接层的入度。对于一个分类模型来说,全连接层的出度是分类的数量。因为我们想对 0 - 9 一共 10 个数字进行分类,所以出度为 10。</p><p>重点说一下 <code>nn.LogSoftmax</code>。softmax 是分类任务中常用的手段,将输出值转化为 $(0,1)$ 之间、总和为 1 的概率分布。softmax 的计算公式为 $\frac{e^{x_i}}{\sum e^{x_i}}$,如果某个 $x$ 远小于其它值,它对应的概率会极小,可能低于浮点数的精度而下溢为 0,所以我们一般对概率分布取对数,将其转化为 $(-\infty,0)$ 内的分布。<code>nn.LogSoftmax</code> 就是进行这个运算的类。<code>nn.LogSoftmax</code> 对应的损失函数为 <code>nn.NLLLoss</code>。因为类别分布在第二维上,所以指定 <code>dim=1</code>。</p>
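<p>下面用一个小例子直观感受 softmax 与 log-softmax 的关系(输出数值为四舍五入后的结果):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">>>> x = torch.tensor([[1., 2., 3.]])</span><br><span class="line">>>> torch.softmax(x, dim=1)</span><br><span class="line">tensor([[0.0900, 0.2447, 0.6652]])</span><br><span class="line">>>> torch.log_softmax(x, dim=1)</span><br><span class="line">tensor([[-2.4076, -1.4076, -0.4076]])</span><br></pre></td></tr></table></figure><p>关于损失函数的具体介绍请看<a href="https://vincent507cpu.github.io/2020/05/18/DL-PyTorch-折桂-9:损失函数/">《[DL] PyTorch 折桂 9:损失函数》</a>。</p><h2 id="2-3-损失函数"><a href="#2-3-损失函数" class="headerlink" title="2.3 损失函数"></a>2.3 损失函数</h2><p>上面已经提到了,如果使用 <code>nn.LogSoftmax</code> 作为模型的输出,损失函数应该使用 <code>nn.NLLLoss</code>。这里不多赘述。</p><figure class="highlight py"><table><tr><td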
class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">criterion = nn.NLLLoss()</span><br></pre></td></tr></table></figure><h2 id="2-4-优化器"><a href="#2-4-优化器" class="headerlink" title="2.4 优化器"></a>2.4 优化器</h2><p><a href="https://vincent507cpu.github.io/2020/05/24/DL-PyTorch-折桂-10:torch-optim/">《[DL] PyTorch 折桂 10:torch.optim》</a> 提到,通常我们可以无脑选择 <code>torch.optim.Adam</code>。但是 MNIST 手写数字识别是一个非常简单的任务,使用 SGD 足矣,这次我们使用 <code>torch.optim.SGD</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">optimizer = optim.SGD(model.parameters(), lr=<span class="number">0.003</span>, momentum=<span class="number">0.9</span>)</span><br></pre></td></tr></table></figure><h2 id="2-5-迭代训练"><a href="#2-5-迭代训练" class="headerlink" title="2.5 迭代训练"></a>2.5 迭代训练</h2><p>每一轮训练的流程如下:</p><ol><li>优化器的导数记录清零;</li><li>使用模型得到预测值;</li><li>使用损失函数计算预测值与真实值之间的损失;</li><li>反向传播;</li><li>更新权重。</li></ol><p>因为优化器里的梯度是累积的,每一轮训练都要执行第一步的清零;清零只要发生在本轮反向传播(第 4 步)之前即可,放在紧挨第 4 步前或上一轮第 5 步后都可以。此外可以根据需要加入进度报告。代码如下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">for</span> e <span class="keyword">in</span> range(epochs):</span><br><span class="line"> running_loss = <span class="number">0</span></span><br><span class="line"> <span class="keyword">for</span> images, labels <span class="keyword">in</span> trainloader:</span><br><span class="line"> images = images.view(images.shape[<span class="number">0</span>], <span class="number">-1</span>) <span class="comment"># 打平数据</span></span><br><span class="line"> </span><br><span class="line"> optimizer.zero_grad() <span class="comment"># 导数清零</span></span><br><span class="line"> output = model(images) <span class="comment"># 得到预测值</span></span><br><span class="line"> loss = criterion(output, labels) <span class="comment"># 计算损失</span></span><br><span class="line"> </span><br><span class="line"> loss.backward() <span class="comment"># 反向传播</span></span><br><span class="line"> optimizer.step() <span class="comment"># 优化权重</span></span><br><span class="line"> </span><br><span class="line"> running_loss += loss.item()</span><br><span class="line"> <span class="keyword">else</span>:</span><br><span class="line"> print(<span class="string">"Epoch {} - Training loss: {}"</span>.format(e, running_loss/len(trainloader)))</span><br></pre></td></tr></table></figure><p>我们设置 <code>epochs = 15</code> 运行一下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span
class="line">15</span><br></pre></td><td class="code"><pre><span class="line">Epoch <span class="number">0</span> - Training loss: <span class="number">0.46956403385093215</span></span><br><span class="line">Epoch <span class="number">1</span> - Training loss: <span class="number">0.33383476238515075</span></span><br><span class="line">Epoch <span class="number">2</span> - Training loss: <span class="number">0.31380205746017287</span></span><br><span class="line">Epoch <span class="number">3</span> - Training loss: <span class="number">0.3029081499509847</span></span><br><span class="line">Epoch <span class="number">4</span> - Training loss: <span class="number">0.2956352831442346</span></span><br><span class="line">Epoch <span class="number">5</span> - Training loss: <span class="number">0.2905418651063305</span></span><br><span class="line">Epoch <span class="number">6</span> - Training loss: <span class="number">0.2873595496103453</span></span><br><span class="line">Epoch <span class="number">7</span> - Training loss: <span class="number">0.2838163320173714</span></span><br><span class="line">Epoch <span class="number">8</span> - Training loss: <span class="number">0.2816906003777915</span></span><br><span class="line">Epoch <span class="number">9</span> - Training loss: <span class="number">0.27968987264930567</span></span><br><span class="line">Epoch <span class="number">10</span> - Training loss: <span class="number">0.27738782898512987</span></span><br><span class="line">Epoch <span class="number">11</span> - Training loss: <span class="number">0.2752566468248616</span></span><br><span class="line">Epoch <span class="number">12</span> - Training loss: <span class="number">0.27330243247134217</span></span><br><span class="line">Epoch <span class="number">13</span> - Training loss: <span class="number">0.2733802362513949</span></span><br><span class="line">Epoch <span class="number">14</span> - Training loss: <span class="number">0.27021837964463336</span></span><br></pre></td></tr></table></figure><p>可以看到,模型似乎在学习,在第 10 个 epoch 稳定。</p><h2 id="2-6-评估模型"><a href="#2-6-评估模型" class="headerlink" title="2.6 评估模型"></a>2.6 评估模型</h2><p>PyTorch 没有 TensorFlow 方便的评估功能,所有评估都要手工定义。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br></pre></td><td class="code"><pre><span class="line">correct_count, all_count = <span class="number">0</span>, <span class="number">0</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">for</span> images,labels <span class="keyword">in</span> valloader:</span><br><span class="line"> <span class="keyword">for</span> i <span class="keyword">in</span> range(len(labels)):</span><br><span class="line"> img = images[i].view(<span class="number">1</span>, <span class="number">784</span>) <span class="comment"># 取出第 i 个元素</span></span><br><span class="line"> </span><br><span 
class="line"> <span class="keyword">with</span> torch.no_grad(): <span class="comment"># 关闭求导功能</span></span><br><span class="line"> logps = model(img) <span class="comment"># 获得预测值</span></span><br><span class="line"></span><br><span class="line"> </span><br><span class="line"> ps = torch.exp(logps) <span class="comment"># 将对数去掉</span></span><br><span class="line"> pred_label = torch.argmax(ps[<span class="number">0</span>]) <span class="comment"># 获得最大概率的标签</span></span><br><span class="line"> true_label = labels[i] <span class="comment"># 获得真实数据的标签</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">if</span>(true_label == pred_label): <span class="comment"># 如果预测与真实值相同则加 1</span></span><br><span class="line"> correct_count += <span class="number">1</span></span><br><span class="line"> </span><br><span class="line"> all_count += <span class="number">1</span></span><br><span class="line"></span><br><span class="line">print(<span class="string">"Number Of Images Tested ="</span>, all_count)</span><br><span class="line">print(<span class="string">"Model Accuracy ="</span>, (correct_count/all_count))</span><br></pre></td></tr></table></figure><p>详情见代码注释。这里只说一点:在测试的时候我们不需要计算梯度,代码里用 <code>with torch.no_grad()</code> 关闭了求导;实践中通常还会配合 <code>model.eval()</code> 把模型切换到评估模式(它影响 dropout、batch norm 等层的行为,但本身并不会关闭求导)。<br>我们看一下测试结果:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">Number Of Images Tested = <span class="number">10000</span></span><br><span class="line">Model Accuracy = <span class="number">0.9222</span></span><br></pre></td></tr></table></figure><p>我们的模型仅仅使用了一个全连接层就获得了 92.2% 的准确率,如果我们加入多个全连接层并且使用 dropout 等方法,准确率可以轻松超过 97%。</p><h2 id="2-7-保存模型"><a href="#2-7-保存模型" class="headerlink" title="2.7 保存模型"></a>2.7 保存模型</h2><p>PyTorch 的模型文件的扩展名一般是 <code>pt</code> 或 <code>pth</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.save(model, <span class="string">'./my_mnist_model.pt'</span>)</span><br></pre></td></tr></table></figure>
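<p>保存之后可以这样载入(一个简单示意:<code>torch.save(model, ...)</code> 保存的是整个模型对象,载入时需要能访问到模型类的定义;更通用的做法是只保存 <code>state_dict</code>):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">model = torch.load('./my_mnist_model.pt')  # 载入整个模型对象</span><br><span class="line">model.eval()  # 推理前切换到评估模式</span><br><span class="line"># 更通用的做法:只保存参数,之后用 load_state_dict 载入</span><br><span class="line"># torch.save(model.state_dict(), './my_mnist_model.pt')</span><br></pre></td></tr></table></figure>]]></content>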
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 10:torch.optim</title>
<link href="2020/05/24/DL-PyTorch-%E6%8A%98%E6%A1%82-10%EF%BC%9Atorch-optim/"/>
<url>2020/05/24/DL-PyTorch-%E6%8A%98%E6%A1%82-10%EF%BC%9Atorch-optim/</url>
<content type="html"><![CDATA[<h1 id="1-优化器"><a href="#1-优化器" class="headerlink" title="1. 优化器"></a>1. 优化器</h1><p>优化器就是根据梯度对参数进行更新的类,绝大多数优化器本质上都是梯度下降法,只是在实现的细节上有所不同。类似的,PyTorch 里的所有优化器都继承自 <code>torch.optim.Optimizer</code> 这个基类。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.optim.Optimizer(params, defaults)</span><br></pre></td></tr></table></figure><p><code>params</code> 是优化器要优化的权重,是一个迭代器;<code>defaults</code> 是优化器在参数以外的默认参数,根据继承它的子类有所不同。</p><a id="more"></a><h2 id="1-1-优化器的种类"><a href="#1-1-优化器的种类" class="headerlink" title="1.1 优化器的种类"></a>1.1 <a href="https://zhuanlan.zhihu.com/p/64885176" target="_blank" rel="noopener" title="PyTorch 学习笔记(七):PyTorch的十个优化器">优化器的种类</a></h2><ul><li><code>torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False)</code></li></ul><p>基础优化器,可以使用 <code>momentum</code> 来避免陷入 local minima。</p><ul><li><code>torch.optim.ASGD</code>:SGD 的改进版,使用平均随机梯度下降。</li></ul><p>下面的若干种优化器都来自于同一个算法:Adaptive Gradient estimation,自适应梯度估计。</p><ul><li><code>torch.optim.Rprop</code>:实现 resilient backpropagation algorithm,弹性反向传播。不适用于 mini-batch,因此现在较少使用。</li><li><code>torch.optim.Adagrad</code>:Adagrad 是一种自适应优化方法,能自适应地为各个参数分配不同的学习率。这个学习率的变化,会受到梯度的大小和迭代次数的影响。梯度越大,学习率越小;梯度越小,学习率越大。缺点是训练后期,学习率过小,因为 Adagrad 累加之前所有的梯度平方作为分母。</li><li><code>torch.optim.Adadelta</code>:实现 Adadelta 优化方法。Adadelta 是 Adagrad 的改进。Adadelta 分母中采用距离当前时间点比较近的累计项,这可以避免在训练后期,学习率过小。</li><li><code>torch.optim.RMSprop</code>:实现 RMSprop 优化方法(Hinton 提出),RMS 是均方根(root mean square)的意思。RMSprop 和 Adadelta 一样,也是对 Adagrad 的一种改进。RMSprop 采用均方根作为分母,可缓解 Adagrad 学习率下降较快的问题。并且引入均方根,可以减少摆动。</li><li><code>torch.optim.Adam</code>:Adam 是对上面的自适应算法的改进,是一种自适应学习率的优化方法,Adam 利用梯度的一阶矩估计和二阶矩估计动态地调整学习率。吴老师课上说过,Adam 是结合了 Momentum 和 RMSprop,并进行了偏差修正。</li><li><code>torch.optim.Adamax</code>:Adamax 对 Adam 增加了一个学习率上限的概念。</li><li><code>torch.optim.SparseAdam</code>:用于稀疏张量的优化器。</li><li><code>torch.optim.LBFGS</code>:实现 L-BFGS(Limited-memory Broyden–Fletcher–Goldfarb–Shanno)优化方法。L-BFGS 属于拟牛顿算法,是对 BFGS 的改进,特点就是节省内存。</li></ul><h2 id="1-2-创建优化器"><a href="#1-2-创建优化器" class="headerlink" title="1.2 创建优化器"></a>1.2 创建优化器</h2><p>可以看出,<code>Adam</code> 优化器是集大成的优化器,一般无脑使用 <code>Adam</code> 即可。本文以 <code>Adam</code> 为例。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.optim.Adam(params, lr=<span class="number">0.001</span>, betas=(<span class="number">0.9</span>, <span class="number">0.999</span>), eps=<span class="number">1e-08</span>, weight_decay=<span class="number">0</span>, amsgrad=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure><ul><li><code>params (iterable)</code>:可用于迭代优化的参数或者定义参数组的 dicts。</li><li><code>lr (float, optional)</code>:学习率(默认:1e-3)</li><li><code>betas (Tuple[float, float], optional)</code>:用于计算梯度的平均和平方的系数(默认:(0.9, 0.999))</li><li><code>eps (float, optional)</code>:为了提高数值稳定性而添加到分母的一个项(默认:1e-8)</li><li><code>weight_decay (float, optional)</code>:权重衰减(如 L2 惩罚,默认:0)</li></ul><p>对于 <code>torch.optim.Adam</code> 来说,只有 <code>params</code> 是必需的参数。可以用如下方法进行创建:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">optimizer = optim.Adam(model.parameters(), lr=<span class="number">0.0001</span>)</span><br></pre></td></tr></table></figure><p>也可以指定每个参数选项。只需传递一个可迭代的 dict 来替换先前可迭代的 Variable。dict 
中的每一项都可以定义为一个单独的参数组,参数组用一个 params 键来包含属于它的参数列表。其他键应该与优化器接受的关键字参数相匹配,才能用作此组的优化选项。比如:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">optim.SGD([</span><br><span class="line"> {<span class="string">'params'</span>: model.base.parameters()},</span><br><span class="line"> {<span class="string">'params'</span>: model.classifier.parameters(), <span class="string">'lr'</span>: <span class="number">1e-3</span>}</span><br><span class="line"> ], lr=<span class="number">1e-2</span>, momentum=<span class="number">0.9</span>)</span><br></pre></td></tr></table></figure><p>如上,<code>model.base.parameters()</code> 将使用 1e-2 的学习率,<code>model.classifier.parameters()</code> 将使用 1e-3 的学习率。0.9 的 <code>momentum</code> 作用于所有的 parameters。</p><h2 id="1-3-优化器的属性"><a href="#1-3-优化器的属性" class="headerlink" title="1.3 优化器的属性"></a>1.3 优化器的属性</h2><p>因为优化器都继承自 <code>torch.optim.Optimizer</code>,所以它们的属性相同。我们先构建一个优化器的实例(注意张量需要 <code>requires_grad=True</code> 才能被优化):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>weight1 = torch.ones((<span class="number">2</span>, <span class="number">2</span>), requires_grad=<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>optimizer = torch.optim.Adam([weight1], lr=<span class="number">1e-2</span>)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(optimizer)</span><br><span class="line">Adam (</span><br><span class="line">Parameter Group <span class="number">0</span></span><br><span class="line"> amsgrad: <span class="literal">False</span></span><br><span class="line"> betas: (<span class="number">0.9</span>, <span class="number">0.999</span>)</span><br><span class="line"> eps: <span class="number">1e-08</span></span><br><span class="line"> lr: <span class="number">0.01</span></span><br><span class="line"> weight_decay: <span class="number">0</span></span><br><span class="line">)</span><br></pre></td></tr></table></figure><ul><li><code>param_groups</code><br>返回优化器的参数组。参数组是一个列表,每个元素是一个组的字典。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>print(optimizer.param_groups)</span><br><span class="line">[{<span class="string">'params'</span>: [tensor([[<span class="number">1.</span>, <span class="number">1.</span>],</span><br><span class="line"> [<span class="number">1.</span>, <span class="number">1.</span>]], requires_grad=<span class="literal">True</span>)], <span class="string">'lr'</span>: <span class="number">0.01</span>, <span class="string">'betas'</span>: (<span class="number">0.9</span>, <span class="number">0.999</span>), <span class="string">'eps'</span>: <span class="number">1e-08</span>, <span class="string">'weight_decay'</span>: <span class="number">0</span>, <span class="string">'amsgrad'</span>: <span
class="literal">False</span>}]</span><br></pre></td></tr></table></figure></li><li><code>add_param_group(param_group)</code><br>添加参数组。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>weight2 = torch.zeros(<span class="number">2</span>,<span class="number">2</span>)</span><br><span class="line"><span class="meta">>>> </span>optimizer.add_param_group({<span class="string">'params'</span>:weight2, <span class="string">'lr'</span>:<span class="number">0.01</span>})</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>optimizer.param_groups</span><br><span class="line">[{<span class="string">'params'</span>: [tensor([[<span class="number">1.</span>, <span class="number">1.</span>],</span><br><span class="line"> [<span class="number">1.</span>, <span class="number">1.</span>]], requires_grad=<span class="literal">True</span>)],</span><br><span class="line"> <span class="string">'lr'</span>: <span class="number">0.01</span>,</span><br><span class="line"> <span class="string">'betas'</span>: (<span class="number">0.9</span>, <span class="number">0.999</span>),</span><br><span class="line"> <span class="string">'eps'</span>: <span class="number">1e-08</span>,</span><br><span class="line"> <span class="string">'weight_decay'</span>: <span class="number">0</span>,</span><br><span class="line"> <span class="string">'amsgrad'</span>: <span class="literal">False</span>},</span><br><span class="line"> {<span class="string">'params'</span>: [tensor([[<span class="number">0.</span>, <span class="number">0.</span>],</span><br><span class="line"> [<span class="number">0.</span>, <span class="number">0.</span>]])],</span><br><span class="line"> <span class="string">'lr'</span>: <span class="number">0.01</span>,</span><br><span class="line"> <span class="string">'betas'</span>: (<span class="number">0.9</span>, <span class="number">0.999</span>),</span><br><span class="line"> <span class="string">'eps'</span>: <span class="number">1e-08</span>,</span><br><span class="line"> <span class="string">'weight_decay'</span>: <span class="number">0</span>,</span><br><span class="line"> <span class="string">'amsgrad'</span>: <span class="literal">False</span>}]</span><br></pre></td></tr></table></figure></li><li><code>state_dict()</code><br>返回优化器的状态。这个属性与 <code>param_group</code> 的区别在于 <code>state_dict()</code> 的返回值包含了梯度的状态。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span 
class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>optimizer.state_dict()</span><br><span class="line">{<span class="string">'state'</span>: {}, <span class="comment"># 进行反向传播以前梯度为空</span></span><br><span class="line"> <span class="string">'param_groups'</span>: [{<span class="string">'lr'</span>: <span class="number">0.01</span>,</span><br><span class="line"> <span class="string">'betas'</span>: (<span class="number">0.9</span>, <span class="number">0.999</span>),</span><br><span class="line"> <span class="string">'eps'</span>: <span class="number">1e-08</span>,</span><br><span class="line"> <span class="string">'weight_decay'</span>: <span class="number">0</span>,</span><br><span class="line"> <span class="string">'amsgrad'</span>: <span class="literal">False</span>,</span><br><span class="line"> <span class="string">'params'</span>: [<span class="number">140685302219312</span>]}]}</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>optimizer.step()</span><br><span class="line"><span class="meta">>>> </span>optimizer.state_dict()</span><br><span class="line">{<span class="string">'state'</span>: {<span class="number">140685302219312</span>: {<span class="string">'step'</span>: <span class="number">1</span>, <span class="comment"># 反向传播以后有了状态</span></span><br><span class="line"> <span class="string">'exp_avg'</span>: tensor([[<span class="number">0.1000</span>, <span class="number">0.1000</span>],</span><br><span class="line"> [<span class="number">0.1000</span>, <span class="number">0.1000</span>]]),</span><br><span class="line"> <span class="string">'exp_avg_sq'</span>: tensor([[<span class="number">0.0010</span>, <span class="number">0.0010</span>],</span><br><span class="line"> [<span class="number">0.0010</span>, <span class="number">0.0010</span>]])}},</span><br><span class="line"> <span class="string">'param_groups'</span>: [{<span class="string">'lr'</span>: <span class="number">0.01</span>,</span><br><span class="line"> <span class="string">'betas'</span>: (<span class="number">0.9</span>, <span class="number">0.999</span>),</span><br><span class="line"> <span class="string">'eps'</span>: <span class="number">1e-08</span>,</span><br><span class="line"> <span class="string">'weight_decay'</span>: <span class="number">0</span>,</span><br><span class="line"> <span class="string">'amsgrad'</span>: <span class="literal">False</span>,</span><br><span class="line"> <span class="string">'params'</span>: [<span class="number">140685302219312</span>]},</span><br><span class="line"> {<span class="string">'lr'</span>: <span class="number">0.01</span>,</span><br><span class="line"> <span class="string">'betas'</span>: (<span class="number">0.9</span>, <span class="number">0.999</span>),</span><br><span class="line"> <span class="string">'eps'</span>: <span class="number">1e-08</span>,</span><br><span class="line"> <span class="string">'weight_decay'</span>: <span class="number">0</span>,</span><br><span class="line"> <span class="string">'amsgrad'</span>: <span 
class="literal">False</span>,</span><br><span class="line"> <span class="string">'params'</span>: [<span class="number">140685312958784</span>]}]}</span><br></pre></td></tr></table></figure></li><li><code>load_state_dict(state_dict)</code><br>载入已经保存的参数组。这个属性与模型的保存与载入一并介绍。</li><li><code>step()</code><br>根据当前存储的梯度执行一次参数更新(反向传播本身由 <code>loss.backward()</code> 完成)。</li><li><code>zero_grad()</code><br>将优化器内存储的梯度清零。</li></ul><h1 id="2-改变学习率"><a href="#2-改变学习率" class="headerlink" title="2. 改变学习率"></a>2. 改变学习率</h1><p><code>torch.optim.lr_scheduler</code> 中提供了多种根据 epoch 数目调整学习率的方法。优化器需要被包含进 scheduler 实例里。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>scheduler = ...(optimizer, ...) <span class="comment">#优化器被包含进来</span></span><br><span class="line"><span class="meta">>>> </span><span class="keyword">for</span> epoch <span class="keyword">in</span> range(<span class="number">100</span>):</span><br><span class="line"><span class="meta">>>> </span> train(...)</span><br><span class="line"><span class="meta">>>> </span> validate(...)</span><br><span class="line"><span class="meta">>>> </span> scheduler.step()</span><br></pre></td></tr></table></figure><ul><li><code>torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)</code>:将每一个参数组的学习率设置为初始学习率 lr 的某个函数倍;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="comment"># Assuming optimizer has two groups.</span></span><br><span class="line"><span class="meta">>>> </span>lambda1 = <span class="keyword">lambda</span> epoch: epoch // <span class="number">30</span></span><br><span class="line"><span class="meta">>>> </span>lambda2 = <span class="keyword">lambda</span> epoch: <span class="number">0.95</span> ** epoch</span><br><span class="line"><span class="meta">>>> </span>scheduler = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2])</span><br><span class="line"><span class="meta">>>> </span><span class="keyword">for</span> epoch <span class="keyword">in</span> range(<span class="number">100</span>):</span><br><span class="line"><span class="meta">>>> </span> train(...)</span><br><span class="line"><span class="meta">>>> </span> validate(...)</span><br><span class="line"><span class="meta">>>> </span> scheduler.step()</span><br></pre></td></tr></table></figure></li><li><code>torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda, last_epoch=-1)</code>:每个 epoch 将当前学习率乘以 lr_lambda 返回的系数;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>lmbda = <span class="keyword">lambda</span> epoch: <span class="number">0.95</span></span><br><span class="line"><span class="meta">>>> </span>scheduler = MultiplicativeLR(optimizer, lr_lambda=lmbda)</span><br><span class="line"><span class="meta">>>> </span><span
class="keyword">for</span> epoch <span class="keyword">in</span> range(<span class="number">100</span>):</span><br><span class="line"><span class="meta">>>> </span> train(...)</span><br><span class="line"><span class="meta">>>> </span> validate(...)</span><br><span class="line"><span class="meta">>>> </span> scheduler.step()</span><br></pre></td></tr></table></figure></li><li><code>torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)</code>:每过 step_size 个 epoch 将学习率乘以一次 gamma,即 $lr=lr_{initial}\times\gamma^{\lfloor epoch/\text{step\_size}\rfloor}$;<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="comment"># Assuming optimizer uses lr = 0.05 for all groups</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># lr = 0.05 if epoch < 30</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># lr = 0.005 if 30 <= epoch < 60</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># lr = 0.0005 if 60 <= epoch < 90</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># ...</span></span><br><span class="line"><span class="meta">>>> </span>scheduler = StepLR(optimizer, step_size=<span class="number">30</span>, gamma=<span class="number">0.1</span>)</span><br><span class="line"><span class="meta">>>> </span><span class="keyword">for</span> epoch <span class="keyword">in</span> range(<span class="number">100</span>):</span><br><span class="line"><span class="meta">>>> </span> train(...)</span><br><span class="line"><span class="meta">>>> </span> validate(...)</span><br><span class="line"><span class="meta">>>> </span> scheduler.step()</span><br></pre></td></tr></table></figure></li><li><code>torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1)</code>:当 epoch 数达到 milestones 中设定的值时,将学习率乘以 gamma。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="comment"># Assuming optimizer uses lr = 0.05 for all groups</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># lr = 0.05 if epoch < 30</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># lr = 0.005 if 30 <= epoch < 80</span></span><br><span class="line"><span class="meta">>>> </span><span class="comment"># lr = 0.0005 if epoch >= 80</span></span><br><span class="line"><span class="meta">>>> </span>scheduler = MultiStepLR(optimizer, milestones=[<span class="number">30</span>,<span class="number">80</span>], gamma=<span class="number">0.1</span>)</span><br><span class="line"><span class="meta">>>> </span><span class="keyword">for</span> epoch <span class="keyword">in</span> range(<span class="number">100</span>):</span><br><span class="line"><span class="meta">>>> </span> train(...)</span><br><span class="line"><span class="meta">>>> </span> validate(...)</span><br><span class="line"><span class="meta">>>> </span> scheduler.step()</span><br></pre></td></tr></table></figure></li></ul><p>它们的公共参数:</p><ul><li><code>lr_lambda</code>:描述学习率变化的匿名函数;</li><li><code>gamma</code>:倍数系数;</li><li><code>last_epoch</code>:最后一次 epoch 的索引,若为 -1 则为初始 epoch。</li></ul>
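<p>下面是一个最小的示意(参数是随意设的),通过 <code>optimizer.param_groups[0]['lr']</code> 观察 StepLR 对学习率的调整;真实训练中 <code>optimizer.step()</code> 之前还应有完整的前向与反向传播:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">param = torch.ones(2, 2, requires_grad=True)</span><br><span class="line">optimizer = torch.optim.SGD([param], lr=0.05)</span><br><span class="line">scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)</span><br><span class="line"></span><br><span class="line">for epoch in range(90):</span><br><span class="line">    optimizer.step()    # 这里省略了 forward 与 backward</span><br><span class="line">    scheduler.step()    # 每个 epoch 结束后更新学习率</span><br><span class="line">    print(epoch, optimizer.param_groups[0]['lr'])</span><br></pre></td></tr></table></figure><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt="在这里插入图片描述"></p>]]></content>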
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 9:损失函数</title>
<link href="2020/05/18/DL-PyTorch-%E6%8A%98%E6%A1%82-9%EF%BC%9A%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0/"/>
<url>2020/05/18/DL-PyTorch-%E6%8A%98%E6%A1%82-9%EF%BC%9A%E6%8D%9F%E5%A4%B1%E5%87%BD%E6%95%B0/</url>
<content type="html"><![CDATA[<h1 id="1-损失函数总览"><a href="#1-损失函数总览" class="headerlink" title="1. 损失函数总览"></a>1. 损失函数总览</h1><p>PyTorch 的 Loss Function(损失函数)的实现都在 <code>torch.nn.functional</code> 里,同时 <code>torch.nn</code> 里提供了封装好的类。PyTorch 里共有 18 个左右的损失函数,常用的有以下几个:</p><ol><li>回归模型:</li></ol><ul><li><code>torch.nn.L1Loss</code></li><li><code>torch.nn.MSELoss</code></li></ul><ol start="2"><li>分类模型:</li></ol><ul><li><code>torch.nn.BCELoss</code></li><li><code>torch.nn.BCEWithLogitsLoss</code></li><li><code>torch.nn.CrossEntropyLoss</code></li><li><code>torch.nn.NLLLoss</code><a id="more"></a></li></ul><p>损失函数是用来衡量模型的单个预测与真实值的差异的:<br>$$Loss=f(\hat{y},y)$$<br>还有额外的两个概念:Cost Function(代价函数)是 N 个预测值的损失函数平均值:<br>$$Cost=\frac{1}{N}\sum^N_if(\hat{y}_i,y_i)$$<br>而 Objective Function(目标函数)是最终需要优化的函数:<br>$$Obj=Cost+Regularization$$</p><p>还有其它的损失函数,学识有限,暂时不理解。希望以后有缘能够接触。</p><h1 id="2-回归损失函数"><a href="#2-回归损失函数" class="headerlink" title="2. 回归损失函数"></a>2. 回归损失函数</h1><p>回归模型有两种方法进行评估:MAE(mean absolute error)和 MSE(mean squared error)。</p><ul><li><code>torch.nn.L1Loss(reduction='mean')</code></li></ul><p>这个类对应了 MAE 损失函数:<br>$$\ell=L=\{l_1,\dots,l_N\},\quad l_n=|\hat{y}_n-y_n|$$</p><ul><li><code>torch.nn.MSELoss(reduction='mean')</code></li></ul><p>这个类对应了 MSE 损失函数:<br>$$\ell=L=\{l_1,\dots,l_N\},\quad l_n=(\hat{y}_n-y_n)^2$$<br>上面两个类中的 <code>reduction</code> 规定了获得 $\ell$ 后的行为,有 <code>none</code>、<code>sum</code> 和 <code>mean</code> 三个选项。<code>none</code> 表示不对 $\ell$ 进行任何处理;<code>sum</code> 表示对 $\ell$ 进行求和;<code>mean</code> 表示对 $\ell$ 进行平均。默认为求平均。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>y = torch.tensor([<span class="number">1.1</span>, <span class="number">1.2</span>, <span class="number">1.3</span>])</span><br><span class="line"><span class="meta">>>> </span>y_hat = torch.tensor([<span class="number">1.</span>, <span class="number">1.</span>, <span class="number">1.</span>])</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>criterion_none = nn.L1Loss(reduction=<span class="string">'none'</span>) <span class="comment"># 什么都不做</span></span><br><span class="line"><span class="meta">>>> </span>criterion_none(y_hat, y)</span><br><span class="line">tensor([<span class="number">0.1000</span>, <span class="number">0.2000</span>, <span class="number">0.3000</span>])</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>criterion_mean = nn.L1Loss(reduction=<span class="string">'mean'</span>) <span class="comment"># 求平均</span></span><br><span class="line"><span class="meta">>>> </span>criterion_mean(y_hat, y)</span><br><span class="line">tensor(<span class="number">0.2000</span>)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>criterion_sum = nn.L1Loss(reduction=<span class="string">'sum'</span>) <span class="comment"># 求和</span></span><br><span class="line"><span class="meta">>>> </span>criterion_sum(y_hat, y)</span><br><span class="line">tensor(<span
class="number">0.6000</span>)</span><br></pre></td></tr></table></figure><h1 id="3-分类损失函数"><a href="#3-分类损失函数" class="headerlink" title="3. 分类损失函数"></a>3. 分类损失函数</h1><h2 id="3-1-交叉熵"><a href="#3-1-交叉熵" class="headerlink" title="3.1 交叉熵"></a>3.1 <a href="https://charlesliuyx.github.io/2017/09/11/什么是信息熵、交叉熵和相对熵/" target="_blank" rel="noopener" title="【直观详解】信息熵、交叉熵和相对熵">交叉熵</a></h2><p>自信息是一个事件发生的概率的负对数:<br>$$I(x)=-log[p(x)]$$<br>信息熵用来描述一个事件的不确定性公式为<br>$$H(P)=-\sum^N_iP(x_i)logP(x_i)$$<br>一个确定的事件的信息熵为 0,一个事件越不确定,信息熵就越大。</p><p>交叉熵,用来衡量在给定的真实分布下,使用非真实分布指定的策略消除系统的不确定性所需要付出努力的大小,表达式为<br>$$H(P,Q)=-\sum^B_{i=1}P(x_i)logQ(x_i)$$<br>相对熵又叫 “K-L 散度”,用来描述预测事件对真实事件的概率偏差。<br>$$D_{KL}(P,Q)=E\bigg[log\frac{P(x)}{Q(x)}\bigg]\<br>=E\bigg[logP(x)-logQ(x)\bigg]\<br>=\sum^N_{i=1}P(x_i)[logP(x_i)-logQ(x_i)]\<br>=\sum^N_{i=1}P(x_i)logP(x_i)-\sum^N_{i=1}P(x_i)logQ(x_i)\<br>=H(P,Q)-H(P)$$<br>而交叉熵的表达式为<br>$$H(P,Q)=-\sum^N_{i=1}P(x_i)logQ(x_i)$$<br>可见 $H(P,Q)=H(P)+D_{KL}(P,Q)$,即交叉熵是信息熵和相对熵的和。上面的 $P$ 是事件的真实分布,$Q$ 是预测出来的分布。所以优化 $H(P,Q)$ 等价于优化 $H(Q)$,因为 $H(P)$ 是已知不变的。</p><h2 id="3-2-分类损失函数"><a href="#3-2-分类损失函数" class="headerlink" title="3.2 分类损失函数"></a>3.2 分类损失函数</h2><p>下面我们来了解最常用的四个分类损失函数。</p><ul><li><code>torch.nn.BCELoss(weight=None, reduction='mean')</code><br>这个类实现了二分类交叉熵。<br>$$l_n=-w_n[y_n\cdot logx_n+(1-y_n)\cdot log(1-x_n)]$$<br>使用这个类时要注意,输入值(<strong>不是分类</strong>)的范围要在 $(0,1)$ 之间,否则会报错。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>inputs = torch.tensor([[<span class="number">1</span>, <span class="number">2</span>], [<span class="number">2</span>, <span class="number">2</span>], [<span class="number">3</span>, <span class="number">4</span>], [<span class="number">4</span>, <span class="number">5</span>]], dtype=torch.float)</span><br><span class="line"><span class="meta">>>> </span>target = torch.tensor([[<span class="number">1</span>, <span class="number">0</span>], [<span class="number">1</span>, <span class="number">0</span>], [<span class="number">0</span>, <span class="number">1</span>], [<span class="number">0</span>, <span class="number">1</span>]], dtype=torch.float)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>criterion = nn.BCELoss()</span><br><span class="line"><span class="meta">>>> </span>criterion(inputs, target)</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">RuntimeError Traceback (most recent call last)</span><br><span class="line">...</span><br><span class="line">RuntimeError: all elements of input should be between <span class="number">0</span> <span class="keyword">and</span> <span class="number">1</span></span><br></pre></td></tr></table></figure>通常可以先使用 <code>F.sigmoid</code> 处理一下数据。</li><li><code>torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)</code><br>与上面的 <code>torch.nn.BCELoss</code> 相似,只是 $x$ 先使用了 sigmoid 处理了一下,这样就不需要手动使用 sigmoid 的了。<br>$$l_n=-w_n[y_n\cdot log\sigma(x_n)+(1-y_n)\cdot log(1-\sigma(x_n))]$$</li><li><code>torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')</code><br>NLLLoss 的全称为 “negative log likelihood 
<li><code>torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')</code><br>NLLLoss 的全称为 “negative log likelihood loss”,其作用是实现负对数似然函数中的负号。<br>$$\ell=L=\{l_1,\dots,l_N\},\quad l_n=-w_{y_n}x_{n,y_n}$$</li><li><code>torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')</code><br>这个类结合了 <code>nn.LogSoftmax</code> 和 <code>nn.NLLLoss</code>。这个类的运算可以写成:<br>$$loss(class)=weight[class]\bigg(-\text{log}\bigg(\frac{\text{exp}(x[class])}{\sum_j\text{exp}(x[j])}\bigg)\bigg)\<br>=weight[class]\bigg(-x[class]+\text{log}\bigg(\sum_j\text{exp}(x[j])\bigg)\bigg)$$<br>对比上面 $H(P,Q)$ 的公式:分类任务中真实分布 $P$ 是 one-hot 的,只有真实类别处 $P(x_i)=1$,其余项都为 0,求和因此只剩一项,式子就简化成了 $H(P,Q)=-logQ(x)$。然后我们需要把 $x[class]$ 归一化到一个概率分布中,所以使用 softmax。</li><li><code>torch.nn.KLDivLoss(reduction='mean')</code><br>这个类就是上面提到的相对熵,注意它要求输入 $x_n$ 是对数概率:<br>$$l(x,y)=L=\{l_1,\dots,l_N\},\quad l_n=y_n\cdot(\text{log}\ y_n-x_n)$$<br>这几个类的参数类似,除了上面提到的 <code>reduction</code>,还有一个 <code>weight</code>,就是每一个类别的权重。下面用例子来解释交叉熵和 <code>weight</code> 是如何运作的。我们先定义一组数据,使用 numpy 推演一下:<figure class="highlight py"><table><tr><td class="code"><pre><span class="line">import numpy as np</span><br><span class="line">import torch</span><br><span class="line"></span><br><span class="line">inputs = torch.tensor([[1, 1], [1, 2], [3, 3]], dtype=torch.float)</span><br><span class="line">target = torch.tensor([0, 0, 1], dtype=torch.long)</span><br><span class="line"></span><br><span class="line">idx = target[0]</span><br><span class="line"></span><br><span class="line">input_ = inputs.detach().numpy()[idx]  # [1. 1.]</span><br><span class="line">target_ = target.numpy()[idx]  # 0</span><br><span class="line"></span><br><span class="line"># 第一项:x[class]</span><br><span class="line">x_class = input_[target_]</span><br><span class="line"></span><br><span class="line"># 第二项:log(sum(exp(x[j])))</span><br><span class="line">sigma_exp_x = np.sum(list(map(np.exp, input_)))</span><br><span class="line">log_sigma_exp_x = np.log(sigma_exp_x)</span><br><span class="line"></span><br><span class="line"># 输出 loss</span><br><span class="line">loss_1 = -x_class + log_sigma_exp_x</span><br></pre></td></tr></table></figure>结果为<figure class="highlight py"><table><tr><td class="code"><pre><span class="line">>>> print("第一个样本loss为: ", loss_1)</span><br><span class="line">第一个样本loss为:  0.6931473</span><br></pre></td></tr></table></figure>现在我们再使用 PyTorch 来计算:
<figure class="highlight py"><table><tr><td class="code"><pre><span class="line">>>> criterion_ce = nn.CrossEntropyLoss(reduction='none')</span><br><span class="line">>>> criterion_ce(inputs, target)</span><br><span class="line">tensor([0.6931, 1.3133, 0.6931])</span><br></pre></td></tr></table></figure>可以看到,结果是一致的。现在我们再看看 <code>weight</code>:<figure class="highlight py"><table><tr><td class="code"><pre><span class="line">>>> weight = torch.tensor([0.1, 0.9], dtype=torch.float)</span><br><span class="line">>>> criterion_ce = nn.CrossEntropyLoss(weight=weight, reduction='none')</span><br><span class="line">>>> criterion_ce(inputs, target)</span><br><span class="line">tensor([0.0693, 0.1313, 0.6238])</span><br></pre></td></tr></table></figure>与没有权重的交叉熵进行比较后可以发现,每个样本的 loss 都乘以了其真实类别对应的权重 $w_{y_n}$:前两个样本的类别为 0,乘以 0.1;第三个样本的类别为 1,乘以 0.9。当 <code>reduction</code> 为 <code>sum</code> 时,输出就是加权后的总和;为 <code>mean</code> 时,加权总和还要再除以各样本权重之和(本例为 $0.1+0.1+0.9=1.1$),而不是除以样本数。</li></ul><h2 id="3-3-总结"><a href="#3-3-总结" class="headerlink" title="3.3 总结"></a>3.3 总结</h2><ol><li><code>F.sigmoid</code> + <code>torch.nn.BCELoss</code> = <code>torch.nn.BCEWithLogitsLoss</code></li><li><code>nn.LogSoftmax</code> + <code>nn.NLLLoss</code> = <code>torch.nn.CrossEntropyLoss</code></li></ol>
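<p>第二条等价关系同样可以用上文的 <code>inputs</code> 和 <code>target</code> 验证(输出与前面的推演一致):</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">>>> log_softmax = nn.LogSoftmax(dim=1)</span><br><span class="line">>>> criterion_nll = nn.NLLLoss(reduction='none')</span><br><span class="line">>>> criterion_nll(log_softmax(inputs), target)</span><br><span class="line">tensor([0.6931, 1.3133, 0.6931])</span><br></pre></td></tr></table></figure><p>欢迎关注我的微信公众号“花解语 NLP”:<br><img src="https://img-blog.csdnimg.cn/20200514100635366.jpg#pic_center" alt="在这里插入图片描述"></p>]]></content>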
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 8:torch.nn.init</title>
<link href="2020/05/18/DL-PyTorch-%E6%8A%98%E6%A1%82-8%EF%BC%9Atorch-nn-init/"/>
<url>2020/05/18/DL-PyTorch-%E6%8A%98%E6%A1%82-8%EF%BC%9Atorch-nn-init/</url>
<content type="html"><![CDATA[<h1 id="1-torch-nn-init-概述"><a href="#1-torch-nn-init-概述" class="headerlink" title="1. torch.nn.init 概述"></a>1. torch.nn.init 概述</h1><p>因为神经网络的训练过程其实是寻找最优解的过程,所以神经元的初始值非常重要。如果初始值恰好在最优解附近,神经网络的训练会非常简单。而当神经网络的层数增加以后,一个突出的问题就是梯度消失和梯度爆炸。前者指的是由于梯度接近 0,导致神经元无法进行更新;后者指的是<a href="https://zhuanlan.zhihu.com/p/32154263" target="_blank" rel="noopener" title="浅谈神经网络中的梯度爆炸问题">误差梯度在更新中累积得到一个非常大的梯度,这样的梯度会大幅度更新网络参数,进而导致网络不稳定</a>。</p><p><code>torch.nn.init</code> 模块提供了合理初始化初始值的方法。它一共提供了四类初始化方法:</p><ol><li>Xavier 分布初始化;</li><li>Kaiming 分布初始化;</li><li>均匀分布、正态分布、常数分布初始化;</li><li>其它初始化。<a id="more"></a>有梯度边界的激活函数如 <code>sigmoid</code>、<code>tanh</code> 和 <code>softmax</code> 等被称为饱和函数,没有梯度边界的激活函数如 <code>relu</code> 被称为不饱和函数,它们对应的初始化方法不同。<h1 id="2-梯度消失和梯度爆炸"><a href="#2-梯度消失和梯度爆炸" class="headerlink" title="2. 梯度消失和梯度爆炸"></a>2. 梯度消失和梯度爆炸</h1>假设我们有一个 3 层的全连接网络:<br><img src="https://img-blog.csdnimg.cn/2020051809553052.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="在这里插入图片描述"><br>对倒数第二层神经元的权重进行反向传播的公式为:<br>$$\Delta W_3=\frac{\partial loss}{\partial W_3}=\frac{\partial loss}{\partial out}*\frac{\partial out}{\partial H_3}*\frac{\partial H_3}{\partial W_3}$$<br>而 $H_3=H_2*W_3$,所以<br>$$\Delta W_3=\frac{\partial loss}{\partial out}*\frac{\partial out}{\partial H_3}*H_2$$<br>即 $Hi_2$ ,即上一层的神经元的输出值,决定了 $\Delta W_3$ 的大小。如果 $H_2$ 太大或太小,即梯度消失或梯度爆炸,将导致神经网络无法训练。对于 <code>sigmoid</code> 和 <code>tanh</code> 等梯度绝对值小于 1 的激活函数来说,神经元的值会越来越小;对于其它情况,假设我们构建了一个 100 层的全连接网络:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">MLP</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, neural_num, layers)</span>:</span></span><br><span class="line"> super(MLP, self).__init__()</span><br><span class="line"> self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=<span class="literal">False</span>) <span class="keyword">for</span> _ <span class="keyword">in</span> range(layers)])</span><br><span class="line"> self.neural_num = neural_num</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> <span class="keyword">for</span> (i, linear) <span class="keyword">in</span> enumerate(self.linears):</span><br><span class="line"> x 
= linear(x)</span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> x</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">init</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">for</span> m <span class="keyword">in</span> self.modules():</span><br><span class="line"> <span class="keyword">if</span> isinstance(m, nn.Linear):</span><br><span class="line"> nn.init.normal_(m.weight.data)</span><br><span class="line"></span><br><span class="line">layers=<span class="number">100</span></span><br><span class="line">neural_num=<span class="number">256</span></span><br><span class="line">batch_size=<span class="number">16</span></span><br><span class="line"></span><br><span class="line">net = MLP(neural_num, layers)</span><br><span class="line">net.init()</span><br><span class="line"></span><br><span class="line">inputs = torch.randn(batch_size, neural_num)</span><br><span class="line">output = net(inputs)</span><br></pre></td></tr></table></figure>打印一下神经网络的输出:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>print(output)</span><br><span class="line">tensor([[nan, nan, nan, ..., nan, nan, nan],</span><br><span class="line"> [nan, nan, nan, ..., nan, nan, nan],</span><br><span class="line"> [nan, nan, nan, ..., nan, nan, nan],</span><br><span class="line"> ...,</span><br><span class="line"> [nan, nan, nan, ..., nan, nan, nan],</span><br><span class="line"> [nan, nan, nan, ..., nan, nan, nan],</span><br><span class="line"> [nan, nan, nan, ..., nan, nan, nan]], grad_fn=<MmBackward>)</span><br></pre></td></tr></table></figure>可以看到,神经元的值都变成了 <code>nan</code>。这是为什么呢?</li></ol><p>因为方差可以表征数据的离散程度,让我们来打印一下每次神经元的值的方差:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line">layers: <span class="number">0</span>, std: <span class="number">15.7603178024292</span></span><br><span class="line">layers: <span class="number">1</span>, std: <span class="number">253.5698699951172</span></span><br><span class="line">layers: <span class="number">2</span>, std: <span class="number">4018.8212890625</span></span><br><span class="line">layers: <span class="number">3</span>, std: <span 
class="number">64962.9453125</span></span><br><span class="line">layers: <span class="number">4</span>, std: <span class="number">1050192.125</span></span><br><span class="line">layers: <span class="number">5</span>, std: <span class="number">16682177.0</span></span><br><span class="line">...</span><br><span class="line">layers: <span class="number">28</span>, std: <span class="number">8.295319341711625e+34</span></span><br><span class="line">layers: <span class="number">29</span>, std: <span class="number">1.2787049888311946e+36</span></span><br><span class="line">layers: <span class="number">30</span>, std: <span class="number">2.0164275976565801e+37</span></span><br><span class="line">layers: <span class="number">31</span>, std: nan</span><br><span class="line"></span><br><span class="line">output <span class="keyword">is</span> nan at <span class="number">31</span>th layers</span><br><span class="line"></span><br><span class="line">tensor([[ <span class="number">1.3354e+38</span>, <span class="number">-2.0165e+38</span>, <span class="number">-3.2402e+37</span>, ..., <span class="number">1.0439e+37</span>,</span><br><span class="line"> -inf, <span class="number">1.2574e+38</span>],</span><br><span class="line"> [ -inf, -inf, inf, ..., -inf,</span><br><span class="line"> -inf, inf],</span><br><span class="line"> [ <span class="number">1.2230e+37</span>, -inf, <span class="number">5.6356e+37</span>, ..., <span class="number">-1.2776e+38</span>,</span><br><span class="line"> inf, -inf],</span><br><span class="line"> ...,</span><br><span class="line"> [ <span class="number">2.1591e+37</span>, <span class="number">2.5838e+38</span>, <span class="number">-2.9146e+38</span>, ..., inf,</span><br><span class="line"> -inf, -inf],</span><br><span class="line"> [ inf, <span class="number">1.9056e+38</span>, -inf, ..., inf,</span><br><span class="line"> -inf, -inf],</span><br><span class="line"> [ -inf, inf, <span class="number">-1.7735e+38</span>, ..., <span class="number">4.8110e+37</span>,</span><br><span class="line"> inf, -inf]], grad_fn=<MmBackward>)</span><br></pre></td></tr></table></figure><p>可以看到,到第 30 层的时候,神经元的值已经非常大或非常小,终于在第 31 层的时候,神经元的值突破了存储精度的极限,只好变成了 <code>nan</code>。</p><p>我们知道,一组数的方差 $D$ 和 期望 $E$ 在 $X$ 与 $Y$ 相互独立的条件下满足下面的性质:<br>$$E(X*Y)=E(X)*E(Y)$$<br>$$D(X)=E(X^2)-[E(X)]^2$$<br>$$D(X+Y)=D(X)+D(Y)$$<br>所以有:<br>$$D(X*Y)=D(X)*D(Y)+D(X)*[E(Y)]^2+D(Y)*[E(X)]^2$$<br>当 $E(X)=0$,$E(Y)=0$ 的时候:<br>$$D(X*Y)=D(X)*D(Y)$$<br>在神经网络中,由于全连接层的性质<br>$$H_{11}=\sum^n_{i=0}X_I*W_{1i}$$<br>得<br>$$D(H_{11})=\sum^n_{i=0}D(X_i)*D(W_{1i})\<br>=n*(1*1)\<br>=n$$<br>因为 $X_i$ 服从一个方差为 1 的正态分布,而 $W_i$ 也服从一个方差为 1 的分布,所以 $D(H_{11})$ 的值就是神经元的个数,因此标准差就是 $\sqrt{n}$。而全连接的性质决定了第 $k$ 层的神经元的标准差为 $\sqrt{n^k}$,与上面例子中 256 个神经元的情况基本吻合。</p><p>为了让神经网络的神经元值稳定,我们希望将每一层神经元的方差维持在 1,这样每一次前向传播后的方差仍然是 1,使模型保持稳定。这被称为“方差一致性准则”。因为$D(H_{11})=n*D(X_i)*D(W_{1i})$,为了让 $D(H_i)=1$,我们只需要让 $D(W_i)=\frac{1}{n}$ 即 $std(W)=\sqrt{\frac{1}{n}}$。我们验证一下:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span 
class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">MLP</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, neural_num, layers)</span>:</span></span><br><span class="line"> super(MLP, self).__init__()</span><br><span class="line"> self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=<span class="literal">False</span>) <span class="keyword">for</span> _ <span class="keyword">in</span> range(layers)])</span><br><span class="line"> self.neural_num = neural_num</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> <span class="keyword">for</span> (i, linear) <span class="keyword">in</span> enumerate(self.linears):</span><br><span class="line"> x = linear(x)</span><br><span class="line"> print(<span class="string">f'layers: <span class="subst">{i}</span>, std: <span class="subst">{x.std()}</span>'</span>)</span><br><span class="line"> <span class="keyword">if</span> torch.isnan(x.std()):</span><br><span class="line"> print(<span class="string">f'output is nan at <span class="subst">{i}</span>th layers'</span>)</span><br><span class="line"> <span class="keyword">break</span></span><br><span class="line"> </span><br><span class="line"> <span class="keyword">return</span> x</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">init</span><span class="params">(self)</span>:</span></span><br><span class="line"> <span class="keyword">for</span> m <span class="keyword">in</span> self.modules():</span><br><span class="line"> <span class="keyword">if</span> isinstance(m, nn.Linear):</span><br><span class="line"> nn.init.normal_(m.weight.data, std=np.sqrt(<span class="number">1</span>/self.neural_num))</span><br><span class="line">layers=<span class="number">100</span></span><br><span class="line">neural_num=<span class="number">256</span></span><br><span class="line">batch_size=<span class="number">16</span></span><br><span class="line"></span><br><span class="line">net = MLP(neural_num, layers)</span><br><span class="line">net.init()</span><br><span class="line"></span><br><span class="line">inputs = torch.randn(batch_size, neural_num)</span><br><span class="line">output = net(inputs)</span><br></pre></td></tr></table></figure><p>打印一下神经网络的神经元值:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span 
class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br></pre></td><td class="code"><pre><span class="line">layers: <span class="number">0</span>, std: <span class="number">0.9983504414558411</span></span><br><span class="line">layers: <span class="number">1</span>, std: <span class="number">0.9868919253349304</span></span><br><span class="line">layers: <span class="number">2</span>, std: <span class="number">0.9728540778160095</span></span><br><span class="line">layers: <span class="number">3</span>, std: <span class="number">0.9823500514030457</span></span><br><span class="line">layers: <span class="number">4</span>, std: <span class="number">0.9672497510910034</span></span><br><span class="line">layers: <span class="number">5</span>, std: <span class="number">0.9902626276016235</span></span><br><span class="line">...</span><br><span class="line">layers: <span class="number">95</span>, std: <span class="number">1.0507267713546753</span></span><br><span class="line">layers: <span class="number">96</span>, std: <span class="number">1.0782362222671509</span></span><br><span class="line">layers: <span class="number">97</span>, std: <span class="number">1.1384222507476807</span></span><br><span class="line">layers: <span class="number">98</span>, std: <span class="number">1.1450780630111694</span></span><br><span class="line">layers: <span class="number">99</span>, std: <span class="number">1.138461709022522</span></span><br><span class="line">tensor([[<span class="number">-0.6622</span>, <span class="number">0.4439</span>, <span class="number">0.5704</span>, ..., <span class="number">-2.2066</span>, <span class="number">-1.1012</span>, <span class="number">0.0450</span>],</span><br><span class="line"> [<span class="number">-0.1037</span>, <span class="number">-0.3485</span>, <span class="number">-0.0313</span>, ..., <span class="number">-0.1562</span>, <span class="number">-0.0520</span>, <span class="number">0.6481</span>],</span><br><span class="line"> [ <span class="number">0.3136</span>, <span class="number">-0.0966</span>, <span class="number">-1.5647</span>, ..., <span class="number">-0.8760</span>, <span class="number">-0.7498</span>, <span class="number">0.6339</span>],</span><br><span class="line"> ...,</span><br><span class="line"> [<span class="number">-0.6644</span>, <span class="number">-0.4354</span>, <span class="number">0.8103</span>, ..., <span class="number">1.1510</span>, <span class="number">0.7699</span>, <span class="number">0.0607</span>],</span><br><span class="line"> [<span class="number">-0.7511</span>, <span class="number">-0.1086</span>, <span class="number">0.4008</span>, ..., <span class="number">1.5456</span>, <span class="number">0.6027</span>, <span class="number">-0.0303</span>],</span><br><span class="line"> [<span class="number">-0.5602</span>, <span class="number">-0.1664</span>, <span class="number">-0.9711</span>, ..., <span class="number">-1.0884</span>, <span class="number">-0.7040</span>, <span class="number">0.7415</span>]],</span><br><span class="line"> grad_fn=<MmBackward>)</span><br></pre></td></tr></table></figure><p>神经元的值果然是稳定的。</p><h1 id="3-torch-nn-init-calculate-gain"><a href="#3-torch-nn-init-calculate-gain" class="headerlink" title="3. torch.nn.init.calculate_gain"></a>3. 
<code>torch.nn.init.calculate_gain</code></h1><p>这个函数返回给定激活函数的推荐增益值,用来补偿激活函数对数据方差的缩放。比如方差为 1 的数据经过线性变换(Linearity)后方差不变,增益为 1;而经过 <code>relu</code> 后负半轴被置 0,方差大约减半,所以推荐增益为 $\sqrt{2}$。PyTorch 给出了常见激活函数的变化增益:</p><table><thead><tr><th>激活函数</th><th>变化增益</th></tr></thead><tbody><tr><td>Linearity</td><td>1</td></tr><tr><td>ConvND</td><td>1</td></tr><tr><td>Sigmoid</td><td>1</td></tr><tr><td>Tanh</td><td>$\frac{5}{3}$</td></tr><tr><td>ReLU</td><td>$\sqrt{2}$</td></tr><tr><td>Leaky ReLU</td><td>$\sqrt{\frac{2}{1+\text{negative\_slope}^2}}$</td></tr></tbody></table><p>这个函数的参数如下:<code>torch.nn.init.calculate_gain(nonlinearity, param=None)</code></p><ul><li><code>nonlinearity</code>:激活函数名称;</li><li><code>param</code>:激活函数的参数(比如 Leaky ReLU 的 <code>negative_slope</code>)。</li></ul><h1 id="4-Xavier-initialization"><a href="#4-Xavier-initialization" class="headerlink" title="4. Xavier initialization"></a>4. Xavier initialization</h1><p>为了解决饱和激活函数里的权重初始化问题,2010 年 Glorot 和 Bengio 发表了<a href="http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf" target="_blank" rel="noopener" title="Understanding the difficulty of training deep feedforward neural networks">《Understanding the difficulty of training deep feedforward neural networks》</a>论文,正式提出了 Xavier 初始化。Xavier 初始化通常使用均匀分布。由论文得,初始化后的张量中的值采样自 $U[-a,a]$ 且<br>$$a=\text{gain}\times\sqrt{\frac{6}{n_i+n_{i+1}}}$$<br>其中 $n_i$、$n_{i+1}$ 分别是该层的输入、输出特征数。均匀分布下的 Xavier 初始化函数为 <code>torch.nn.init.xavier_uniform_(tensor, gain=1)</code>。</p><p>Xavier 初始化也可以采用正态分布的方式,对应函数为 <code>torch.nn.init.xavier_normal_(tensor, gain=1)</code>。初始化后的张量中的值采样自正态分布 $N(0,\text{std}^2)$ 且<br>$$\text{std}=\text{gain}\times\sqrt{\frac{2}{n_i+n_{i+1}}}$$</p><h1 id="5-Kaiming-initialization"><a href="#5-Kaiming-initialization" class="headerlink" title="5. Kaiming initialization"></a>5. Kaiming initialization</h1><p>2011 年 ReLU 函数横空出世,Xavier 初始化对 ReLU 函数不再适用。2015 年,Kaiming He 提出了<a href="https://arxiv.org/pdf/1502.01852.pdf" title="Delving Deep into Rectifiers:Surpassing Human-Level Performance on ImageNet Classification">另一种初始化方法</a>来适应 ReLU,其权重的方差满足<br>$$D(W)=\frac{2}{(1+a^2)\times n_i},\quad std(W)=\sqrt{D(W)}$$<br>其中 $a$ 是 Leaky ReLU 在 $x<0$ 时的斜率,普通 ReLU 对应 $a=0$。同样的,Kaiming 初始化也有均匀分布和正态分布两种:<br><code>torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')</code>:均匀分布的 Kaiming 初始化函数;</p><p><code>torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')</code>:正态分布的 Kaiming 初始化函数。</p><h1 id="6-其它初始化方法"><a href="#6-其它初始化方法" class="headerlink" title="6. 其它初始化方法"></a>6. 其它初始化方法</h1><ul><li><code>torch.nn.init.uniform_(tensor, a=0.0, b=1.0)</code>:初始化为服从 $U[a, b]$ 的均匀分布;</li><li><code>torch.nn.init.normal_(tensor, mean=0.0, std=1.0)</code>:初始化为服从给定 <code>mean</code> 和 <code>std</code> 的正态分布(默认标准正态分布);</li><li><code>torch.nn.init.constant_(tensor, val)</code>:初始化为任一常数;</li><li><code>torch.nn.init.ones_(tensor)</code>:初始化为 1;</li><li><code>torch.nn.init.zeros_(tensor)</code>:初始化为 0;</li><li><code>torch.nn.init.eye_(tensor)</code>:初始化对角线为 1,其它为 0;</li><li><code>torch.nn.init.orthogonal_(tensor, gain=1)</code>:用(半)正交矩阵初始化二维及以上的张量;</li><li><code>torch.nn.init.sparse_(tensor, sparsity, std=0.01)</code>:将每一列中比例为 <code>sparsity</code> 的元素置 0,其余元素采样自标准差为 <code>std</code> 的正态分布。</li></ul>
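<p>下面是一个把 <code>calculate_gain</code> 和初始化函数串起来的最小示意(假设网络仍是上文的 MLP,且层间使用 tanh 激活):</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">for m in net.modules():</span><br><span class="line">    if isinstance(m, nn.Linear):</span><br><span class="line">        gain = nn.init.calculate_gain('tanh')  # 5/3</span><br><span class="line">        nn.init.xavier_uniform_(m.weight.data, gain=gain)  # 均匀分布的 Xavier 初始化</span><br></pre></td></tr></table></figure>]]></content>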
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 7:torch.nn 总览 & nn.Linear & 常用激活函数</title>
<link href="2020/05/18/DL-PyTorch-%E6%8A%98%E6%A1%82-7%EF%BC%9Atorch-nn-%E6%80%BB%E8%A7%88-nn-Linear-%E5%B8%B8%E7%94%A8%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0/"/>
<url>2020/05/18/DL-PyTorch-%E6%8A%98%E6%A1%82-7%EF%BC%9Atorch-nn-%E6%80%BB%E8%A7%88-nn-Linear-%E5%B8%B8%E7%94%A8%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0/</url>
<content type="html"><![CDATA[<h1 id="1-torch-nn-总览"><a href="#1-torch-nn-总览" class="headerlink" title="1 torch.nn 总览"></a>1 <code>torch.nn</code> 总览</h1><p><code>PyTorch</code> 把与深度学习模型搭建相关的全部类全部在 <code>torch.nn</code> 这个子模块中。根据类的功能分类,常用的有如下十几个部分:</p><ul><li><strong>Containers</strong>:容器类,如 <code>torch.nn.Module</code>;</li><li><strong>Convolution Layers</strong>:卷积层,如 <code>torch.nn.Conv2d</code>;</li><li><strong>Pooling Layers</strong>:池化层,如 <code>torch.nn.MaxPool2d</code>;</li><li><strong>Non-linear activations</strong>:非线性激活层,如 <code>torch.nn.ReLU</code>;</li><li><strong>Normalization layers</strong>:归一化层,如 <code>torch.nn.BatchNorm2d</code>;</li><li><strong>Recurrent layers</strong>:循环神经层,如 <code>torch.nn.LSTM</code>;</li><li><strong>Transformer layers</strong>:transformer 层,如 <code>torch.nn.TransformerEncoder</code>;</li><li><strong>Linear layers</strong>:线性连接层,如 <code>torch.nn.Linear</code>;</li><li><strong>Dropout layers</strong>:dropout 层,如 <code>torch.nn.Dropout</code>;</li><li><strong>Sparse layers</strong>:稀疏层,如 <code>torch.nn.Embedding</code>;</li><li><strong>Vision layers</strong>:vision 层,如 <code>torch.nn.Upsample</code>;</li><li><strong>DataParallel layers</strong>:平行计算层,如 <code>torch.nn.DataParallel</code>;</li><li><strong>Utilities</strong>:其它功能,如 <code>torch.nn.utils.clip_grad_value_</code>。<a id="more"></a>而在 <code>torch.nn</code> 下面还有一个子模块 <code>torch.nn.functional</code>,基本上是 <code>torch.nn</code> 里对应类的函数,比如 <code>torch.nn.ReLU</code> 的对应函数是 <code>torch.nn.functional.relu</code>。为什么要这么做呢?<blockquote><p>你可能会疑惑为什么需要这两个功能如此相近的模块,其实这么设计是有其原因的。如果我们只保留 nn.functional 下的函数的话,在训练或者使用时,我们就要手动去维护 weight,bias,stride 这些中间量的值,这显然是给用户带来了不便。而如果我们只保留 nn 下的类的话,其实就牺牲了一部分灵活性,因为做一些简单的计算都需要创造一个类,这也与 PyTorch 的风格不符。(<a href="https://www.zhihu.com/question/66782101" target="_blank" rel="noopener" title="PyTorch 中,nn 与 nn.functional 有什么区别?">知乎回答</a>)</p></blockquote></li></ul><p><code>torch.nn</code> 可以被 <code>nn.Module</code> 识别,并成为网络组成的一部分;<code>torch.nn.functional</code> 则不行。比较以下两个模型:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="class"><span class="keyword">class</span> <span class="title">Simple</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"><span class="meta">... 
</span> super(Simple, self).__init__()</span><br><span class="line"><span class="meta">... </span> self.fc = nn.Linear(<span class="number">10</span>, <span class="number">1</span>)</span><br><span class="line"><span class="meta">... </span> self.dropout = nn.Dropout(<span class="number">0.5</span>) <span class="comment"># 使用 nn.Dropout 类</span></span><br><span class="line"> </span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"><span class="meta">... </span> x = self.fc(x)</span><br><span class="line"><span class="meta">... </span> x = self.dropout(x)</span><br><span class="line"><span class="meta">... </span> <span class="keyword">return</span> x</span><br><span class="line"><span class="meta">>>> </span>simple = Simple()</span><br><span class="line"><span class="meta">>>> </span>print(simple)</span><br><span class="line">Simple(</span><br><span class="line"> (fc): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">1</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (dropout): Dropout(p=<span class="number">0.5</span>, inplace=<span class="literal">False</span>) <span class="comment">#可以被识别成一层</span></span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span><span class="class"><span class="keyword">class</span> <span class="title">Simple2</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"><span class="meta">... </span> super(Simple2, self).__init__()</span><br><span class="line"><span class="meta">... </span> self.fc = nn.Linear(<span class="number">10</span>, <span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"><span class="meta">... </span> x = F.dropout(self.fc(x)) <span class="comment"># 使用 nn.functional.dropout,不能被识别</span></span><br><span class="line"><span class="meta">... 
</span> <span class="keyword">return</span> x</span><br><span class="line"><span class="meta">>>> </span>simple2 = Simple2()</span><br><span class="line"><span class="meta">>>> </span>print(simple2)</span><br><span class="line">Simple2(</span><br><span class="line"> (fc): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">1</span>, bias=<span class="literal">True</span>)</span><br><span class="line">)</span><br></pre></td></tr></table></figure><p>什么时候调用 <code>torch.nn</code>,什么时候调用 <code>torch.nn.functional</code> 呢?个人的经验是:不需要存储权重的时候使用 <code>torch.nn.functional</code>,需要存储权重的时候使用 <code>torch.nn</code>:</p><ul><li>层、dropout 使用 <code>torch.nn</code>;</li><li>激活函数使用 <code>torch.nn.functional</code>。</li></ul><p>这里要额外说一下 dropout 层。理论上 dropout 没有权重,可以使用 <code>torch.nn.functional.dropout</code>,然而 dropout 有 <code>train</code> 和 <code>eval</code> 两种模式,使用 <code>torch.nn.Dropout</code> 可以方便地对模式进行控制,而函数做不到这一点。所以为了方便,推荐使用 <code>torch.nn.Dropout</code>。</p><p>以后若没有特殊说明,均在引入模块时省略 <code>torch</code> 模块名称。</p><h1 id="2-nn-Linear"><a href="#2-nn-Linear" class="headerlink" title="2. nn.Linear"></a>2. <code>nn.Linear</code></h1><p>线性连接层又叫做全连接层(fully connected layer),指的是通过矩阵乘法将前一层的矩阵变换为下一层的矩阵:<br>$$layer1*W+b=layer2$$<br><img src="https://img-blog.csdnimg.cn/20200514091940862.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="在这里插入图片描述"><br><code>W</code> 被称为全连接层的 weights,<code>b</code> 被称为全连接层的 bias。通常为了演示方便,我们忽略 bias。<br><code>layer1</code> 如果是一个 $m*n$ 的矩阵,$W$ 是一个 $n*k$ 的矩阵,那么下一层 <code>layer2</code> 就是一个 $m*k$ 的矩阵。<code>n</code> 称为输入特征数(<code>input size</code>),<code>k</code> 称为输出特征数(<code>output size</code>),那么这个线性连接层可以被这样初始化:</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">fc = nn.Linear(input_size, output_size)</span><br></pre></td></tr></table></figure><p>multilayer perceptron(多层感知机,MLP)就是通过若干个全连接层组合而成的。但是只堆叠线性层并不能提升模型的表达能力,为什么呢?假设一个 MLP 由两个线性层组成,分别为<br>$$x_3=x_2*W_2$$<br>$$x_2=x_1*W_1$$<br>我们把第二个式子中的 $x_2$ 代入第一个式子,可得:<br>$$x_3=(x_1*W_1)*W_2=x_1*(W_1*W_2)$$<br>可见若干层全连接层相连,最终可以化简为一个全连接层。为了解决这个问题,激活函数(activation function)出现了。</p>
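<p>用一个小例子感受一下全连接层的形状变换(张量内容是随机的,只看形状):</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">>>> fc = nn.Linear(10, 3)  # input size 为 10,output size 为 3</span><br><span class="line">>>> x = torch.randn(4, 10)  # batch 中有 4 个样本</span><br><span class="line">>>> fc(x).shape</span><br><span class="line">torch.Size([4, 3])</span><br></pre></td></tr></table></figure><h1 id="3-激活函数"><a href="#3-激活函数" class="headerlink" title="3. 激活函数"></a>3. 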
激活函数</h1><p>激活函数就是非线性连接层,通过非线性函数将一层变为另一层。常用的激活函数有 <code>sigmoid</code>,<code>tanh</code>,<code>relu</code> 及其变种。虽然 <code>torch.nn</code> 有激活函数层,但因为激活函数比较轻量级,使用 <code>torch.nn.functional</code> 里的函数功能就足够了。通常我们将 <code>torch.nn.functional</code> 写成 <code>F</code>:</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">import torch.nn.functional as F</span><br></pre></td></tr></table></figure><ul><li><code>F.sigmoid</code><br><img src="https://img-blog.csdnimg.cn/20200514092014944.jpg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="在这里插入图片描述"><br><code>sigmoid</code> 又叫做 <code>logistic</code>,通常写作 $\sigma$,公式为<br>$$sigmoid(x)=\sigma(x)=\frac{1}{1+e^{-x}}$$<br><code>sigmoid</code> 的值域为 $(0,1)$,所以通常用于二分类问题:大于 $0.5$ 为一类,小于 $0.5$ 为另一类。<code>sigmoid</code> 的导数公式为<br>$$\sigma'(x)=\sigma(x)(1-\sigma(x))$$<br>导数的值域为 $(0,0.25]$。<code>sigmoid</code> 函数的特点为:</li></ul><ol><li>函数的值在 $(0,1)$ 之间,符合概率分布;</li><li>导数的值域为 $(0,0.25]$,容易造成梯度消失;</li><li>输出恒为正,不以 0 为中心,会破坏数据的分布。</li></ol><ul><li><code>F.tanh</code><br><img src="https://img-blog.csdnimg.cn/20200514092037833.jpg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="在这里插入图片描述"><br><code>tanh</code> 是双曲正切函数,公式为<br>$$tanh(x)=\frac{\sinh(x)}{\cosh(x)}=\frac{e^x-e^{-x}}{e^x+e^{-x}}$$<br><code>tanh</code> 的值域为 $(-1,1)$,关于原点对称。它的导数公式为<br>$$tanh'(x)=1-tanh^2(x)$$<br>导数的值域为 $(0,1]$。<code>tanh</code> 的特点为:</li></ul><ol><li>函数值域为 $(-1,1)$,以 0 为中心对称分布;</li><li>导数值域为 $(0,1]$,仍然容易造成梯度消失。</li></ol><ul><li><code>F.relu</code><br><img src="https://img-blog.csdnimg.cn/20200514092058279.jpg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="在这里插入图片描述"><br>为了解决上述两个激活函数容易产生梯度消失的问题,Rectified Linear Unit(<code>relu</code>)横空出世了。它实际上是一个分段函数:<br>$$relu(x)=\begin{cases}0,\ x<0\<br>x,\ x\ge0\end{cases}$$<br><code>relu</code> 的优点在于求导非常方便,而且非常稳定:<br>$$relu'(x)=\begin{cases}0,\ x<0\<br>\text{undefined},\ x=0\<br>1,\ x>0\end{cases}$$<br>缺点在于</li></ul><ol><li>当 $x<0$ 时导数为 0,神经元“死亡”,即不再更新;</li><li>虽然没有梯度消失的问题,但有梯度爆炸的问题。</li></ol><ul><li><code>F.leaky_relu</code><br><img src="https://img-blog.csdnimg.cn/20200514092117274.jpg?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="在这里插入图片描述"><br>为了解决 <code>relu</code> 的神经元“死亡”问题,对其负半轴稍加改动成为了 <code>leaky_relu</code>:<br>$$leakyrelu(x)=\begin{cases}\alpha x,\ x<0\<br>x,\ x\ge0\end{cases}$$<br>$\alpha$ 是一个很小的数,通常是 0.01,这样负半轴也能保留微小的梯度,神经元不会完全“死亡”。它的导数就变成了<br>$$leakyrelu'(x)=\begin{cases}\alpha,\ x<0\<br>1,\ x>0\end{cases}$$</li></ul>
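<p>最后用一个小示意对比一下 <code>relu</code> 和 <code>leaky_relu</code>(<code>negative_slope</code> 取默认的 0.01):</p><figure class="highlight py"><table><tr><td class="code"><pre><span class="line">>>> x = torch.tensor([-2., 0., 3.])</span><br><span class="line">>>> F.relu(x)</span><br><span class="line">tensor([0., 0., 3.])</span><br><span class="line">>>> F.leaky_relu(x)  # 负半轴乘以 0.01</span><br><span class="line">tensor([-0.0200,  0.0000,  3.0000])</span><br></pre></td></tr></table></figure>]]></content>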
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 6:torch.nn.Module</title>
<link href="2020/05/14/DL-PyTorch-%E6%8A%98%E6%A1%82-6%EF%BC%9Atorch-nn-Module/"/>
<url>2020/05/14/DL-PyTorch-%E6%8A%98%E6%A1%82-6%EF%BC%9Atorch-nn-Module/</url>
<content type="html"><![CDATA[<p>本文中,我们看一看如何构建模型。<br>创造一个模型分两步:构建模型和权值初始化。而构建模型又有“定义单独的网络层”和“把它们拼在一起”两步。</p><h1 id="1-torch-nn-Module"><a href="#1-torch-nn-Module" class="headerlink" title="1. torch.nn.Module"></a>1. <code>torch.nn.Module</code></h1><p><code>torch.nn.Module</code> 是所有 <code>torch.nn</code> 中的类的父类。我们来看一个非常简单的神经网络:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">SimpleNet</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> super(SimpleNet,self).__init__()</span><br><span class="line"> self.fc = nn.Linear(x.shape[<span class="number">0</span>], <span class="number">1</span>)</span><br><span class="line"> </span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> x = self.fc(x)</span><br><span class="line"> <span class="keyword">return</span> x</span><br></pre></td></tr></table></figure><a id="more"></a><p>我们随便喂给它一个张量,打印它的网络:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>simpleNet = SimpleNet(torch.tensor((<span class="number">10</span>, <span class="number">2</span>)))</span><br><span class="line"><span class="meta">>>> </span>print(simpleNet)</span><br><span class="line">SimpleNet(</span><br><span class="line"> (fc): Linear(in_features=<span class="number">2</span>, out_features=<span class="number">1</span>, bias=<span class="literal">True</span>)</span><br><span class="line">)</span><br></pre></td></tr></table></figure><p>所有自定义的神经网络都要继承 <code>torch.nn.Module</code>。定义单独的网络层在 <code>__init__</code> 函数中实现,把定义好的网络层拼接在一起在 <code>forward</code> 函数中实现。网络类有两个重要的函数:<code>parameters</code> 存储了模型的权重;<code>modules</code> 存储了模型的结构。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>list(simpleNet.modules())</span><br><span class="line">[SimpleNet(</span><br><span class="line"> (fc): Linear(in_features=<span class="number">2</span>, out_features=<span class="number">1</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> ),</span><br><span class="line"> Linear(in_features=<span class="number">2</span>, out_features=<span class="number">1</span>, bias=<span class="literal">True</span>)]</span><br><span class="line"> </span><br><span class="line"> >>> 
list(simpleNet.parameters())</span><br><span class="line">[Parameter containing:</span><br><span class="line"> tensor([[ <span class="number">0.1533</span>, <span class="number">-0.2574</span>]], requires_grad=<span class="literal">True</span>),</span><br><span class="line"> Parameter containing:</span><br><span class="line"> tensor([<span class="number">-0.1589</span>], requires_grad=<span class="literal">True</span>)]</span><br></pre></td></tr></table></figure><h1 id="2-torch-nn-Sequential"><a href="#2-torch-nn-Sequential" class="headerlink" title="2. torch.nn.Sequential"></a>2. <code>torch.nn.Sequential</code></h1><p>这是一个序列容器,既可以放在模型外面单独构建一个模型,也可以放在模型里面成为模型的一部分。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br><span class="line">28</span><br><span class="line">29</span><br><span class="line">30</span><br><span class="line">31</span><br></pre></td><td class="code"><pre><span class="line"><span class="comment"># 单独成为一个模型</span></span><br><span class="line">model1 = nn.Sequential(</span><br><span class="line"> nn.Conv2d(<span class="number">1</span>,<span class="number">20</span>,<span class="number">5</span>),</span><br><span class="line"> nn.ReLU(),</span><br><span class="line"> nn.Conv2d(<span class="number">20</span>,<span class="number">64</span>,<span class="number">5</span>),</span><br><span class="line"> nn.ReLU()</span><br><span class="line"> )</span><br><span class="line"><span class="comment"># 成为模型的一部分</span></span><br><span class="line"><span class="class"><span class="keyword">class</span> <span class="title">LeNetSequential</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, classes)</span>:</span></span><br><span class="line"> super(LeNetSequential, self).__init__()</span><br><span class="line"> self.features = nn.Sequential(</span><br><span class="line"> nn.Conv2d(<span class="number">3</span>, <span class="number">6</span>, <span class="number">5</span>),</span><br><span class="line"> nn.ReLU(),</span><br><span class="line"> nn.MaxPool2d(kernel_size=<span class="number">2</span>, stride=<span class="number">2</span>),</span><br><span class="line"> nn.Conv2d(<span class="number">6</span>, <span class="number">16</span>, <span class="number">5</span>),</span><br><span class="line"> nn.ReLU(),</span><br><span class="line"> nn.MaxPool2d(kernel_size=<span class="number">2</span>, stride=<span class="number">2</span>),)</span><br><span class="line"></span><br><span class="line"> self.classifier = nn.Sequential(</span><br><span class="line"> nn.Linear(<span 
class="number">16</span>*<span class="number">5</span>*<span class="number">5</span>, <span class="number">120</span>),</span><br><span class="line"> nn.ReLU(),</span><br><span class="line"> nn.Linear(<span class="number">120</span>, <span class="number">84</span>),</span><br><span class="line"> nn.ReLU(),</span><br><span class="line"> nn.Linear(<span class="number">84</span>, classes),)</span><br><span class="line"></span><br><span class="line"> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"> x = self.features(x)</span><br><span class="line"> x = x.view(x.size()[<span class="number">0</span>], <span class="number">-1</span>)</span><br><span class="line"> x = self.classifier(x)</span><br><span class="line"> <span class="keyword">return</span> x</span><br></pre></td></tr></table></figure><p>放在模型里面的话,模型还是需要 <code>__init__</code> 和 <code>forward</code> 函数。</p><p>这样构建出来的模型的层没有名字:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>model2 = nn.Sequential(</span><br><span class="line"><span class="meta">... </span> nn.Conv2d(<span class="number">1</span>,<span class="number">20</span>,<span class="number">5</span>),</span><br><span class="line"><span class="meta">... </span> nn.ReLU(),</span><br><span class="line"><span class="meta">... </span> nn.Conv2d(<span class="number">20</span>,<span class="number">64</span>,<span class="number">5</span>),</span><br><span class="line"><span class="meta">... </span> nn.ReLU()</span><br><span class="line"><span class="meta">... 
</span> )</span><br><span class="line"><span class="meta">>>> </span>model2</span><br><span class="line">Sequential(</span><br><span class="line"> (<span class="number">0</span>): Conv2d(<span class="number">1</span>, <span class="number">20</span>, kernel_size=(<span class="number">5</span>, <span class="number">5</span>), stride=(<span class="number">1</span>, <span class="number">1</span>))</span><br><span class="line"> (<span class="number">1</span>): ReLU()</span><br><span class="line"> (<span class="number">2</span>): Conv2d(<span class="number">20</span>, <span class="number">64</span>, kernel_size=(<span class="number">5</span>, <span class="number">5</span>), stride=(<span class="number">1</span>, <span class="number">1</span>))</span><br><span class="line"> (<span class="number">3</span>): ReLU()</span><br><span class="line">)</span><br></pre></td></tr></table></figure><p>为了方便区分不同的层,我们可以使用 <code>collections</code> 里的 <code>OrderedDict</code> 函数:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="keyword">from</span> collections <span class="keyword">import</span> OrderedDict</span><br><span class="line"><span class="meta">>>> </span>model3 = nn.Sequential(OrderedDict([</span><br><span class="line"><span class="meta">... </span> (<span class="string">'conv1'</span>, nn.Conv2d(<span class="number">1</span>,<span class="number">20</span>,<span class="number">5</span>)),</span><br><span class="line"><span class="meta">... </span> (<span class="string">'relu1'</span>, nn.ReLU()),</span><br><span class="line"><span class="meta">... </span> (<span class="string">'conv2'</span>, nn.Conv2d(<span class="number">20</span>,<span class="number">64</span>,<span class="number">5</span>)),</span><br><span class="line"><span class="meta">... </span> (<span class="string">'relu2'</span>, nn.ReLU())</span><br><span class="line"><span class="meta">... </span> ]))</span><br><span class="line"><span class="meta">>>> </span>model3</span><br><span class="line">Sequential(</span><br><span class="line"> (conv1): Conv2d(<span class="number">1</span>, <span class="number">20</span>, kernel_size=(<span class="number">5</span>, <span class="number">5</span>), stride=(<span class="number">1</span>, <span class="number">1</span>))</span><br><span class="line"> (relu1): ReLU()</span><br><span class="line"> (conv2): Conv2d(<span class="number">20</span>, <span class="number">64</span>, kernel_size=(<span class="number">5</span>, <span class="number">5</span>), stride=(<span class="number">1</span>, <span class="number">1</span>))</span><br><span class="line"> (relu2): ReLU()</span><br><span class="line">)</span><br></pre></td></tr></table></figure><h1 id="3-torch-nn-ModuleList"><a href="#3-torch-nn-ModuleList" class="headerlink" title="3. torch.nn.ModuleList"></a>3. 
<code>torch.nn.ModuleList</code></h1><p>将网络层存储进一个列表,可以使用列表生成式快速生成网络,生成的网络层可以被索引,也拥有列表的方法 <code>append</code>,<code>extend</code> 或 <code>insert</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="class"><span class="keyword">class</span> <span class="title">MyModule</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"><span class="meta">... </span> super(MyModule, self).__init__()</span><br><span class="line"><span class="meta">... </span> self.linears = nn.ModuleList([nn.Linear(<span class="number">10</span>, <span class="number">10</span>) <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">10</span>)])</span><br><span class="line"><span class="meta">... </span> self.linears.append(nn.Linear(<span class="number">10</span>, <span class="number">1</span>)) <span class="comment"># append</span></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="keyword">for</span> i, l <span class="keyword">in</span> enumerate(self.linears):</span><br><span class="line"><span class="meta">... </span> x = self.linears[i // <span class="number">2</span>](x) + l(x)</span><br><span class="line"><span class="meta">... 
</span> <span class="keyword">return</span> x</span><br><span class="line"> </span><br><span class="line"><span class="meta">>>> </span>myModeul = MyModule()</span><br><span class="line"><span class="meta">>>> </span>myModeul</span><br><span class="line">MyModule(</span><br><span class="line"> (linears): ModuleList(</span><br><span class="line"> (<span class="number">0</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">1</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">2</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">3</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">4</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">5</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">6</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">7</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">8</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">9</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (<span class="number">10</span>): Linear(in_features=<span class="number">10</span>, out_features=<span class="number">1</span>, bias=<span class="literal">True</span>) <span class="comment"># append 进的层</span></span><br><span class="line"> )</span><br><span class="line">)</span><br></pre></td></tr></table></figure><h1 id="4-torch-nn-ModuleDict"><a href="#4-torch-nn-ModuleDict" class="headerlink" title="4. torch.nn.ModuleDict"></a>4. 
<code>torch.nn.ModuleDict</code></h1><p>这个类以键值对的形式保存子模块,用法与 Python 词典非常类似,拥有 <code>keys</code>,<code>values</code>,<code>items</code>,<code>pop</code>,<code>update</code> 等词典的方法;与 <code>torch.nn.Sequential(OrderedDict(...))</code> 不同,它不会自动定义前向传播的顺序,需要在 <code>forward</code> 中自行取用各层:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br><span class="line">26</span><br><span class="line">27</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="class"><span class="keyword">class</span> <span class="title">MyDictDense</span><span class="params">(nn.Module)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self)</span>:</span></span><br><span class="line"><span class="meta">... </span> super(MyDictDense, self).__init__()</span><br><span class="line"><span class="meta">... </span> self.params = nn.ModuleDict({</span><br><span class="line"><span class="meta">... </span> <span class="string">'linear1'</span>: nn.Linear(<span class="number">512</span>, <span class="number">128</span>),</span><br><span class="line"><span class="meta">... </span> <span class="string">'linear2'</span>: nn.Linear(<span class="number">128</span>, <span class="number">32</span>)</span><br><span class="line"><span class="meta">... </span> })</span><br><span class="line"><span class="meta">... </span> self.params.update({<span class="string">'linear3'</span>: nn.Linear(<span class="number">32</span>, <span class="number">10</span>)}) <span class="comment"># 添加层</span></span><br><span class="line"></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">forward</span><span class="params">(self, x, choice=<span class="string">'linear1'</span>)</span>:</span></span><br><span class="line"><span class="meta">...
</span> <span class="keyword">return</span> self.params[choice](x)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>net = MyDictDense()</span><br><span class="line"><span class="meta">>>> </span>print(net)</span><br><span class="line">MyDictDense(</span><br><span class="line"> (params): ModuleDict(</span><br><span class="line"> (linear1): Linear(in_features=<span class="number">512</span>, out_features=<span class="number">128</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (linear2): Linear(in_features=<span class="number">128</span>, out_features=<span class="number">32</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> (linear3): Linear(in_features=<span class="number">32</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>)</span><br><span class="line"> )</span><br><span class="line">)</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(net.params.keys())</span><br><span class="line">odict_keys([<span class="string">'linear1'</span>, <span class="string">'linear2'</span>, <span class="string">'linear3'</span>])</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>print(net.params.items())</span><br><span class="line">odict_items([(<span class="string">'linear1'</span>, Linear(in_features=<span class="number">512</span>, out_features=<span class="number">128</span>, bias=<span class="literal">True</span>)), (<span class="string">'linear2'</span>, Linear(in_features=<span class="number">128</span>, out_features=<span class="number">32</span>, bias=<span class="literal">True</span>)), (<span class="string">'linear3'</span>, Linear(in_features=<span class="number">32</span>, out_features=<span class="number">10</span>, bias=<span class="literal">True</span>))])</span><br></pre></td></tr></table></figure>
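<p>补充一个小例子(仅作示意):<code>nn.ModuleList</code>、<code>nn.ModuleDict</code> 与普通的 Python 列表/词典最大的区别在于,放进去的层会被注册为子模块,其参数才能被 <code>parameters()</code> 和优化器发现:</p><pre><code class="py">import torch.nn as nn

class WithPlainList(nn.Module):
    def __init__(self):
        super(WithPlainList, self).__init__()
        self.layers = [nn.Linear(10, 10)]  # 普通列表:层不会被注册为子模块

class WithModuleList(nn.Module):
    def __init__(self):
        super(WithModuleList, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10)])  # ModuleList:层被注册为子模块

print(len(list(WithPlainList().parameters())))   # 0,参数对优化器不可见
print(len(list(WithModuleList().parameters())))  # 2,weight 和 bias 均被注册</code></pre>]]></content>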
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 5:PyTorch 模块总览 & torch.utils.data</title>
<link href="2020/05/14/DL-PyTorch-%E6%8A%98%E6%A1%82-5%EF%BC%9APyTorch-%E6%A8%A1%E5%9D%97%E6%80%BB%E8%A7%88-torch-utils-data/"/>
<url>2020/05/14/DL-PyTorch-%E6%8A%98%E6%A1%82-5%EF%BC%9APyTorch-%E6%A8%A1%E5%9D%97%E6%80%BB%E8%A7%88-torch-utils-data/</url>
<content type="html"><![CDATA[<h1 id="1-PyTorch-模块总览"><a href="#1-PyTorch-模块总览" class="headerlink" title="1. PyTorch 模块总览"></a>1. PyTorch 模块总览</h1><p>前面用了四篇文章详细讲解了 tensor 的性质,本篇开始进入功能的介绍。相比 TensorFlow,PyTorch 是非常轻量级的:相比 TensorFlow 追求兼容并包,PyTorch 把外围功能放在了扩展包中,比如 <code>torchtext</code>,以保持主体的轻便。</p><a id="more"></a><p>纵观 PyTorch 的 API,其核心大概如下:</p><ol><li><code>torch.nn</code> & <code>torch.nn.functional</code>:构建神经网络</li><li><code>torch.nn.init</code>:初始化权重</li><li><code>torch.optim</code>:优化器</li><li><code>torch.utils.data</code>:载入数据</li></ol><p>可以说,掌握了上面四个模块和前文中提到的底层 API,至少 80% 的 PyTorch 任务都可以完成。剩下的外围功能则有如下的模块支持:</p><ol start="5"><li><code>torch.cuda</code>:管理 GPU 资源</li><li><code>torch.distributed</code>:分布式训练</li><li><code>torch.jit</code>:构建静态图提升性能</li><li><code>torch.utils.tensorboard</code>:神经网络的可视化</li></ol><p>如果额外掌握了上面四个模块,PyTorch 就只剩下一些边边角角的特殊需求了。</p><p>下面我们来了解第一个功能包:<code>torch.utils.data</code>。这个功能包的作用是收集、打包数据,给数据索引,然后按照 batch 将数据分批喂给神经网络。</p><h1 id="2-torch-utils-data-综述"><a href="#2-torch-utils-data-综述" class="headerlink" title="2. torch.utils.data 综述"></a>2. <code>torch.utils.data</code> 综述</h1><p>PyTorch 数据读取的核心是 <code>torch.utils.data.DataLoader</code> 类。它是一个数据迭代读取器,支持</p><ul><li>映射方式和迭代方式读取数据;</li><li>自定义数据读取顺序;</li><li>自动组 batch;</li><li>单进程或多进程数据读取;</li><li>自动锁页内存(memory pinning)。</li></ul><p>所有上述功能都可以在 <code>torch.utils.data.DataLoader</code> 的变量中定义:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">DataLoader(dataset, batch_size=<span class="number">1</span>, shuffle=<span class="literal">False</span>, sampler=<span class="literal">None</span>,</span><br><span class="line"> batch_sampler=<span class="literal">None</span>, num_workers=<span class="number">0</span>, collate_fn=<span class="literal">None</span>,</span><br><span class="line"> pin_memory=<span class="literal">False</span>, drop_last=<span class="literal">False</span>, timeout=<span class="number">0</span>,</span><br><span class="line"> worker_init_fn=<span class="literal">None</span>)</span><br></pre></td></tr></table></figure><p>最重要的变量为 <code>dataset</code>,它指明了数据的来源。<code>DataLoader</code> 支持两种数据类型:</p><ul><li>映射风格的数据封装(map-style datasets):<br>这种数据结构拥有自定义的 <code>__getitem__()</code> 和 <code>__len__()</code> 属性,可以以“索引/值”的方式读取数据,对应 <code>torch.utils.data.Dataset</code> 类;</li><li>迭代风格的数据封装(iterable-style datasets):<br>这种数据结构拥有自定义的 <code>__iter__()</code> 属性,通常适用于不方便随机获取数据或不定长数据集的读取上,对应 <code>torch.utils.data.IterableDataset</code> 类。</li></ul><p>下面我们从顶层的 <code>torch.utils.data.DataLoader</code> 开始,然后一步一步深入到自定义的细节上。为了方便讨论,我们先人工构建一个数据集:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>samples = torch.arange(<span class="number">100</span>)</span><br><span class="line"><span class="meta">>>> </span>labels = torch.cat([torch.zeros(<span class="number">50</span>), torch.ones(<span class="number">50</span>)], dim=<span class="number">0</span>)</span><br></pre></td></tr></table></figure><h1 id="3-torch-utils-data-DataLoader-数据加载器"><a href="#3-torch-utils-data-DataLoader-数据加载器" class="headerlink" title="3. torch.utils.data.DataLoader 数据加载器"></a>3.
<code>torch.utils.data.DataLoader</code> 数据加载器</h1><p>我们看一下常用的变量:</p><ul><li><code>dataset</code>:数据源;</li><li><code>batch_size</code>:一个整数,定义每一批读取的元素个数;</li><li><code>shuffle</code>:一个布尔值,定义是否随机读取;</li><li><code>sampler</code>:定义获取数据的策略,必须与 <code>shuffle</code> 互斥;</li><li><code>num_workers</code>:一个整数,读取数据使用的子进程数,0 表示在主进程中读取;</li><li><code>collate_fn</code>:一个将读取的数据处理、聚合成一个个 batch 的自定义函数;</li><li><code>drop_last</code>:一个布尔值,如果最后一批数据的个数不足 batch 的大小,是否丢弃这个 batch。</li></ul><p><code>dataset</code>, <code>sampler</code> 和 <code>collate_fn</code> 是自定义的类或功能,我们从后往前看。</p><h1 id="4-数据集的分割"><a href="#4-数据集的分割" class="headerlink" title="4. 数据集的分割"></a>4. 数据集的分割</h1><p>在介绍这三个变量以前,我们先看看如何将数据集分割,比如分成训练集和测试集。</p><ul><li><code>torch.utils.data.Subset(dataset, indices)</code></li></ul><p>这个函数可以根据索引将数据集分割。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>even = [i <span class="keyword">for</span> i <span class="keyword">in</span> range(<span class="number">100</span>) <span class="keyword">if</span> i % <span class="number">2</span> == <span class="number">0</span>]</span><br><span class="line"><span class="meta">>>> </span>new1 = torch.utils.data.Subset(samples, even)</span><br><span class="line"><span class="meta">>>> </span>print(new1[:<span class="number">5</span>])</span><br><span class="line">tensor([<span class="number">0</span>, <span class="number">2</span>, <span class="number">4</span>, <span class="number">6</span>, <span class="number">8</span>])</span><br></pre></td></tr></table></figure><ul><li><code>torch.utils.data.random_split(dataset, lengths)</code></li></ul><p>先将数据随机排列,然后按照指定的长度进行选择。长度的和必须等于数据集中的数据数量。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>train, test = torch.utils.data.random_split(samples, [<span class="number">90</span>, <span class="number">10</span>])</span><br><span class="line"><span class="meta">>>> </span>print(torch.tensor(test))</span><br><span class="line">tensor([<span class="number">79</span>, <span class="number">60</span>, <span class="number">98</span>, <span class="number">74</span>, <span class="number">31</span>, <span class="number">43</span>, <span class="number">21</span>, <span class="number">69</span>, <span class="number">55</span>, <span class="number">76</span>])</span><br></pre></td></tr></table></figure><h1 id="5-collate-fn-核对函数"><a href="#5-collate-fn-核对函数" class="headerlink" title="5. collate_fn 核对函数"></a>5.
<code>collate_fn</code> 核对函数</h1><p>这个变量的功能是在数据被读取后,送进模型前对所有数据进行处理、打包。比如我们有一个不定长度的视频数据集或文本数据集,我们可以自定义一个函数把同一批数据补齐(padding)成相同长度。比如:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = [[<span class="number">1</span>,<span class="number">2</span>,<span class="number">3</span>],[<span class="number">4</span>,<span class="number">5</span>],[<span class="number">6</span>,<span class="number">7</span>,<span class="number">8</span>,<span class="number">9</span>]]</span><br><span class="line"><span class="meta">>>> </span><span class="function"><span class="keyword">def</span> <span class="title">collate_fn</span><span class="params">(data)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="string">'''</span></span><br><span class="line"><span class="string"><span class="meta">... </span> padding data, so they have same length.</span></span><br><span class="line"><span class="string"><span class="meta">... </span> '''</span></span><br><span class="line"><span class="meta">... </span> max_len = max([len(feature) <span class="keyword">for</span> feature <span class="keyword">in</span> data])</span><br><span class="line"><span class="meta">... </span> new = torch.zeros(len(data), max_len)</span><br><span class="line"> </span><br><span class="line"><span class="meta">... </span> <span class="keyword">for</span> i <span class="keyword">in</span> range(len(data)):</span><br><span class="line"><span class="meta">... </span> tmp = torch.as_tensor(data[i])</span><br><span class="line"><span class="meta">... </span> j = len(tmp)</span><br><span class="line"><span class="meta">... </span> new[i][:j] = tmp</span><br><span class="line"> </span><br><span class="line"><span class="meta">... </span> <span class="keyword">return</span> new</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>collate_fn(a)</span><br><span class="line">tensor([[<span class="number">1.</span>, <span class="number">2.</span>, <span class="number">3.</span>, <span class="number">0.</span>],</span><br><span class="line"> [<span class="number">4.</span>, <span class="number">5.</span>, <span class="number">0.</span>, <span class="number">0.</span>],</span><br><span class="line"> [<span class="number">6.</span>, <span class="number">7.</span>, <span class="number">8.</span>, <span class="number">9.</span>]])</span><br></pre></td></tr></table></figure><p>将这个函数赋值给 <code>collate_fn</code>,在读取数据的时候就可以自动对数据进行 padding 并打包成一个 batch。</p><h1 id="6-sampler-采样器"><a href="#6-sampler-采样器" class="headerlink" title="6. sampler 采样器"></a>6.
<code>sampler</code> 采样器</h1><p><code>sampler</code> 变量决定了数据读取的顺序。注意,<code>sampler</code> 只对 map-style datasets 有效,与 iterable-style datasets 并不兼容。除了可以自定义采样器,PyTorch 内置了几种不同的采样器:</p><ul><li><code>torch.utils.data.SequentialSampler(data_source)</code></li></ul><p>默认的采样器,按原始顺序依次读取。</p><ul><li><code>torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)</code></li></ul><p>随机选择数据。可以指定一次读取 <code>num_samples</code> 个数据。<code>replacement</code> 为 <code>True</code> 表示有放回采样,此时才允许指定 <code>num_samples</code>:无放回采样最多只能取出与数据集相同数量的样本,所以采样数量不能随意指定。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>batch = torch.utils.data.RandomSampler(samples, replacement=<span class="literal">True</span>, num_samples=<span class="number">5</span>) <span class="comment"># 生成一个迭代器</span></span><br><span class="line"><span class="meta">>>> </span>print(list(batch))</span><br><span class="line">[<span class="number">85</span>, <span class="number">70</span>, <span class="number">5</span>, <span class="number">63</span>, <span class="number">79</span>]</span><br></pre></td></tr></table></figure><p>注意,<code>RandomSampler</code> 决定的是整个数据集的读取顺序:<code>DataLoader</code> 的 <code>shuffle=True</code> 在内部就是用它实现的。</p><p>还有三个采样器无法独立完成读取,必须先实例化,然后放进 <code>DataLoader</code>:</p><ul><li><code>torch.utils.data.SubsetRandomSampler(indices)</code>:先按照索引选取数据,然后随机排列。</li><li><code>torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)</code>:按照给定的权重采样。<code>weights</code> 是一个与数据集等长的序列,每个元素是对应样本被抽中的相对权重(不要求和为 1),常用于类别不平衡数据的重采样,文末附一个用法示例。</li><li><code>torch.utils.data.BatchSampler(sampler, batch_size, drop_last)</code>:在一个 batch 中应用另外一个采样器。</li></ul><h1 id="7-dataset-数据集生成器"><a href="#7-dataset-数据集生成器" class="headerlink" title="7. dataset 数据集生成器"></a>7. <code>dataset</code> 数据集生成器</h1><ul><li><code>torch.utils.data.IterableDataset</code></li></ul><p>生成一个 iterable-style 的数据封装,需要覆写 <code>__iter__</code> 方法,适合流式数据或不方便随机索引的数据集;配合多进程读取时需要自行为每个 worker 划分数据,避免重复读取。</p><ul><li><code>torch.utils.data.Dataset</code></li></ul><p>这个类需要覆写 <code>__getitem__</code> 和 <code>__len__</code> 方法。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span><span class="class"><span class="keyword">class</span> <span class="title">MyData</span><span class="params">(torch.utils.data.Dataset)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__init__</span><span class="params">(self, data)</span>:</span></span><br><span class="line"><span class="meta">... </span> super(MyData, self).__init__()</span><br><span class="line"><span class="meta">... </span> self.data = data</span><br><span class="line"> </span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__len__</span><span class="params">(self)</span>:</span></span><br><span class="line"><span class="meta">...
</span> <span class="keyword">return</span> len(self.data)</span><br><span class="line"> </span><br><span class="line"><span class="meta">... </span> <span class="function"><span class="keyword">def</span> <span class="title">__getitem__</span><span class="params">(self, index)</span>:</span></span><br><span class="line"><span class="meta">... </span> <span class="keyword">return</span> self.data[index]</span><br><span class="line"> </span><br><span class="line"><span class="meta">>>> </span>mydata = MyData(samples)</span><br><span class="line"><span class="meta">>>> </span>mydata[<span class="number">0</span>]</span><br><span class="line">tensor(<span class="number">0</span>)</span><br><span class="line"><span class="meta">>>> </span>mydata[<span class="number">10</span>:<span class="number">15</span>]</span><br><span class="line">tensor([<span class="number">10</span>, <span class="number">11</span>, <span class="number">12</span>, <span class="number">13</span>, <span class="number">14</span>])</span><br></pre></td></tr></table></figure><h1 id="8-总结"><a href="#8-总结" class="headerlink" title="8. 总结"></a>8. 总结</h1><p>最后,让我们把所有知识应用一下。假设我们想以 10 为一个 batch,随机选择数据:</p><pre><code class="py"><span class="meta">>>> </span>train = MyData(samples)
<span class="meta">>>> </span>ds = torch.utils.data.DataLoader(train, batch_size=<span class="number">10</span>, shuffle=<span class="literal">True</span>)
<span class="meta">>>> </span><span class="keyword">for</span> _ <span class="keyword">in</span> range(<span class="number">5</span>):
<span class="meta">... </span>    print(next(iter(ds)))  <span class="comment"># 每次 iter(ds) 都会重新打乱并新建一个迭代器,这里取的是每个新迭代器的第一个 batch</span>
tensor([<span class="number">22</span>, <span class="number">44</span>, <span class="number">56</span>, <span class="number">38</span>, <span class="number">86</span>, <span class="number">47</span>, <span class="number">14</span>, <span class="number">63</span>, <span class="number">88</span>, <span class="number">64</span>])
tensor([<span class="number">32</span>, <span class="number">38</span>, <span class="number">6</span>, <span class="number">64</span>, <span class="number">67</span>, <span class="number">91</span>, <span class="number">54</span>, <span class="number">3</span>, <span class="number">80</span>, <span class="number">22</span>])
tensor([<span class="number">77</span>, <span class="number">98</span>, <span class="number">61</span>, <span class="number">7</span>, <span class="number">17</span>, <span class="number">97</span>, <span class="number">83</span>, <span class="number">50</span>, <span class="number">26</span>, <span class="number">42</span>])
tensor([<span class="number">67</span>, <span class="number">13</span>, <span class="number">10</span>, <span class="number">83</span>, <span class="number">54</span>, <span class="number">11</span>, <span class="number">31</span>, <span class="number">78</span>, <span class="number">15</span>, <span class="number">36</span>])
tensor([ <span class="number">2</span>, <span class="number">55</span>, <span class="number">87</span>, <span class="number">39</span>, <span class="number">61</span>, <span class="number">92</span>, <span class="number">0</span>, <span class="number">79</span>, <span class="number">69</span>, <span class="number">84</span>])</code></pre>
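<p>最后补充前面挖坑的 <code>WeightedRandomSampler</code> 的一个用法草图(仅作示意,数据是假设的类别不平衡场景):<code>weights</code> 给出每个样本被抽中的相对权重,给少数类更大的权重后,每个 batch 中两类样本的数量会变得接近:</p><pre><code class="py">import torch

# 假设的不平衡数据集:90 个类别 0,10 个类别 1
labels = torch.cat([torch.zeros(90), torch.ones(10)], dim=0)

# 每个样本的采样权重:少数类(1)的权重是多数类(0)的 9 倍
weights = torch.where(labels == 1, torch.tensor(9.), torch.tensor(1.))

sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=100, replacement=True)
loader = torch.utils.data.DataLoader(MyData(labels), batch_size=10, sampler=sampler)

for batch in loader:
    print(batch)  # 具体输出是随机的,但每个 batch 中 0 和 1 的数量大致相当</code></pre>]]></content>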
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 4:torch.autograd</title>
<link href="2020/05/11/DL-PyTorch-%E6%8A%98%E6%A1%82-4%EF%BC%9Atorch-autograph/"/>
<url>2020/05/11/DL-PyTorch-%E6%8A%98%E6%A1%82-4%EF%BC%9Atorch-autograph/</url>
<content type="html"><![CDATA[<p>神经网络的训练过程其实就是一个不断更新权重的过程:更新权重要使用反向传播,而反向传播的本质是求导数。<code>torch.autograd</code> 应运而生,接管了神经网络中不断重复的求导数运算。</p><h1 id="1-计算图"><a href="#1-计算图" class="headerlink" title="1. 计算图"></a>1. 计算图</h1><p>一个深度学习模型是由“计算图”构成的。所谓计算图是一个有向无环图(directed acyclic graph)。数据是这个图的节点(node),运算是这个图的边(edge)。如下图所示:<br><img src="https://img-blog.csdnimg.cn/20200430055244494.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80NDYxNDY4Nw==,size_16,color_FFFFFF,t_70#pic_center" alt="计算图"></p><a id="more"></a><p>这张计算图的数学表达式为 $y=(x+w)*(w+1)$。其中,$x$ 和 $w$ 是由用户定义的,称为“叶子节点”(leaf node),可在 PyTorch 中加以验证(注意 <code>is_leaf</code> 是属性,不是方法):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">a = torch.tensor([<span class="number">1.</span>])</span><br><span class="line">b = torch.tensor([<span class="number">2.</span>])</span><br><span class="line">c = a.add(b)</span><br><span class="line"></span><br><span class="line">a.is_leaf <span class="comment"># True</span></span><br><span class="line">c.is_leaf <span class="comment"># False</span></span><br></pre></td></tr></table></figure><p>计算图可以分为动态图与静态图两种。</p><h2 id="1-1-动态图"><a href="#1-1-动态图" class="headerlink" title="1.1 动态图"></a>1.1 动态图</h2><p>动态图的搭建过程与执行过程可以同时进行。PyTorch 默认采用动态图机制。我们看一个例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line">first_counter = torch.Tensor([<span class="number">0</span>])</span><br><span class="line">second_counter = torch.Tensor([<span class="number">10</span>])</span><br><span class="line"> </span><br><span class="line"><span class="keyword">while</span> (first_counter[<span class="number">0</span>] < second_counter[<span class="number">0</span>]): <span class="comment">#[0] 加不加没有影响</span></span><br><span class="line"> first_counter += <span class="number">2</span></span><br><span class="line"> second_counter += <span class="number">1</span></span><br><span class="line"> </span><br><span class="line">print(first_counter)</span><br><span class="line">print(second_counter)</span><br></pre></td></tr></table></figure><h2 id="1-2-静态图"><a href="#1-2-静态图" class="headerlink" title="1.2 静态图"></a>1.2 静态图</h2><p>静态图先创建计算图,然后执行计算图。计算图一经定义,无法改变。TensorFlow 2.0 以前以静态图为主。我们看同样的例子在 TensorFlow 2.0 以前是怎么搭建的:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span
class="line">18</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> tensorflow <span class="keyword">as</span> tf</span><br><span class="line">first_counter = tf.constant(<span class="number">0</span>) <span class="comment"># 定义变量</span></span><br><span class="line">second_counter = tf.constant(<span class="number">10</span>) <span class="comment"># 定义变量</span></span><br><span class="line"></span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">cond</span><span class="params">(first_counter, second_counter, *args)</span>:</span> <span class="comment"># 定义条件</span></span><br><span class="line"> <span class="keyword">return</span> first_counter < second_counter</span><br><span class="line"><span class="function"><span class="keyword">def</span> <span class="title">body</span><span class="params">(first_counter, second_counter)</span>:</span> <span class="comment"># 定义循环体</span></span><br><span class="line"> first_counter = tf.add(first_counter, <span class="number">2</span>)</span><br><span class="line"> second_counter = tf.add(second_counter, <span class="number">1</span>)</span><br><span class="line"> <span class="keyword">return</span> first_counter, second_counter</span><br><span class="line"> </span><br><span class="line">c1, c2 = tf.while_loop(cond, body, [first_counter, second_counter]) <span class="comment"># 定义循环</span></span><br><span class="line"></span><br><span class="line"><span class="keyword">with</span> tf.Session() <span class="keyword">as</span> sess: <span class="comment"># 建立会话执行计算图</span></span><br><span class="line"> counter_1_res, counter_2_res = sess.run([c1, c2])</span><br><span class="line"></span><br><span class="line">print(counter_1_res)</span><br><span class="line">print(counter_2_res)</span><br></pre></td></tr></table></figure><p>因为静态图在设计好以后不能改变,调试的过程实在太痛苦了。所以 TensorFlow 2.0 开始默认使用动态图。</p><h2 id="1-3-计算图示例"><a href="#1-3-计算图示例" class="headerlink" title="1.3 计算图示例"></a>1.3 计算图示例</h2><p>假如我们想计算上面计算图中 $y=(x+w)*(w+1)$ 在 $x=2$,$w=1$ 时的导数:</p><ol><li>首先,我们将上式进行分解:<br>$$a=x+w$$<br>$$b=w+1$$<br>于是我们得到<br>$$y=a*b$$<br>对上式求导有:<br>$$\frac{\partial y}{\partial w}=\frac{\partial y}{\partial a}\frac{\partial a}{\partial w}+\frac{\partial y}{\partial b}\frac{\partial b}{\partial w}$$<br>根据 $y=a*b$,$a=x+w$ 和 $b=w+1$ 可知:<br>$$\frac{\partial y}{\partial a}=b=w+1$$<br>$$\frac{\partial a}{\partial w}=1$$<br>$$\frac{\partial y}{\partial b}=a=x+w$$<br>$$\frac{\partial b}{\partial w}=1$$<br>所以<br>$$\frac{\partial y}{\partial w}=(w+1)+(x+w)=(1+1)+(2+1)=5$$<br>在 PyTorch 中求导数非常简单,使用 <code>tensor.backward()</code> 即可:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="keyword">import</span> torch</span><br><span class="line"></span><br><span class="line">x = torch.tensor([<span class="number">2.</span>], requires_grad=<span class="literal">True</span>) <span class="comment"># 开启导数追踪</span></span><br><span class="line">w = torch.tensor([<span class="number">1.</span>], requires_grad=<span class="literal">True</span>) <span class="comment"># 开启导数追踪</span></span><br><span
class="line"></span><br><span class="line">a = w.add(x)</span><br><span class="line">b = w.add(<span class="number">1</span>)</span><br><span class="line">y = a.mul(b)</span><br><span class="line"></span><br><span class="line">y.backward() <span class="comment"># 求导</span></span><br><span class="line">print(w.grad)</span><br></pre></td></tr></table></figure></li></ol><h1 id="2-derivative(导数)的概述"><a href="#2-derivative(导数)的概述" class="headerlink" title="2. derivative(导数)的概述"></a>2. derivative(导数)的概述</h1><blockquote><p>当函数 $f$ 的自变量在一点 $x_0$ 上产生一个增量 $h$ 时,函数输出值的增量与自变量增量 $h$ 的比值在 $h$ 趋于0时的极限如果存在,即为 $f$ 在 $x_0$ 处的导数,记作 $f’(x_0)$。如果函数的自变量和取值都是实数的话,那么函数在某一点的导数就是该函数所代表的曲线在这一点上的切线斜率。</p></blockquote><p align="right">-- Wikipedia</p><p>如何求导数是中学的数学知识,这里不再过多赘述,仅仅提一点,对 $z=f(x,y)$ 求 $\frac{\partial z}{\partial x}$ 叫做 “$f$ 关于 $x$ 的偏导数”,此时 $y$ 被看成常量处理。</p><h1 id="3-chain-rule"><a href="#3-chain-rule" class="headerlink" title="3. chain rule"></a>3. chain rule</h1><p>假如我们想对 $z=f(g(x))$ 求导,可以设 $y=g(x), z=f(y)$,则<br>$$\frac{\partial z}{\partial x}=\frac{\partial z}{\partial y}\cdot \frac{\partial y}{\partial x}$$</p><h1 id="4-张量的反向传播"><a href="#4-张量的反向传播" class="headerlink" title="4. 张量的反向传播"></a>4. 张量的反向传播</h1><p>张量的求导函数为:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">tensor.backward(gradient=<span class="literal">None</span>, retain_graph=<span class="literal">None</span>, create_graph=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure><h2 id="4-1-运算结果为-0-维张量的反向传播"><a href="#4-1-运算结果为-0-维张量的反向传播" class="headerlink" title="4.1 运算结果为 0 维张量的反向传播"></a>4.1 运算结果为 0 维张量的反向传播</h2><p>我们自己创建的 tensor 叫做<strong>创建变量</strong>,通过运算生成的 tensor 叫做<strong>结果变量</strong>。tensor 的一个创建方法为</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.tensor(data, dtype=<span class="literal">None</span>, device=<span class="literal">None</span>, requires_grad=<span class="literal">False</span>, pin_memory=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure><p>别的不说,单单说 <code>requires_grad</code>。如果想求这个 tensor 的导数,这个变量必须设为 <code>True</code>。<br><code>requires_grad</code> 的默认值为 <code>False</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor(<span class="number">2.</span>)</span><br><span class="line"><span class="meta">>>> </span>a.requires_grad</span><br><span class="line"><span class="literal">False</span></span><br><span class="line"><span class="meta">>>> </span>a</span><br><span class="line">tensor(<span class="number">2.</span>)</span><br></pre></td></tr></table></figure><p>而所有基于叶子节点生成的 tensor 的 <code>requires_grad</code> 属性与叶子节点相同。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>b = a**<span class="number">2</span> + <span class="number">1</span></span><br><span class="line"><span class="meta">>>> </span>b.requires_grad</span><br><span class="line"><span
class="literal">False</span></span><br><span class="line"><span class="meta">>>> </span>b</span><br><span class="line">tensor(<span class="number">5.</span>)</span><br></pre></td></tr></table></figure><p>如果没有在创建的时候显式声明 <code>requires_grad=True</code>,也可以在用之前临时声明:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a.requires_grad_(<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>a.requires_grad = <span class="literal">True</span> <span class="comment"># 另一种写法</span></span><br><span class="line"><span class="meta">>>> </span>a</span><br><span class="line">tensor(<span class="number">2.</span>, requires_grad=<span class="literal">True</span>)</span><br></pre></td></tr></table></figure><p>重新执行 <code>b = a**2 + 1</code> 之后(计算图是在运算发生时构建的,所以需要重新计算),<code>b</code> 的属性变成了</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">tensor(<span class="number">5.</span>, grad_fn=<AddBackward0>)</span><br></pre></td></tr></table></figure><p>想对 b 求导,使用 <code>b.backward()</code> 即可:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>b.backward()</span><br></pre></td></tr></table></figure><p>查看 <code>b</code> 对 <code>a</code> 在 <code>a = 2</code> 处的导数,使用 <code>a.grad</code> 即可:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a.grad</span><br><span class="line">tensor(<span class="number">4.</span>)</span><br></pre></td></tr></table></figure><p>这个很好理解,$\frac{\partial b}{\partial a} = (a^2+1)' = 2a = 2 \cdot 2 = 4$。</p><h2 id="4-2-运算结果为-1-维以上张量的反向传播"><a href="#4-2-运算结果为-1-维以上张量的反向传播" class="headerlink" title="4.2 运算结果为 1 维以上张量的反向传播"></a>4.2 运算结果为 1 维以上张量的反向传播</h2><p>如果结果为 1 维以上张量,直接求导会出错:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([<span class="number">1.</span>, <span class="number">2.</span>], requires_grad=<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>b = a**<span class="number">2</span> + <span class="number">1</span></span><br><span class="line"><span class="meta">>>> </span>b</span><br><span class="line">tensor([<span class="number">2.</span>, <span class="number">5.</span>], grad_fn=<AddBackward0>)</span><br><span class="line"><span class="meta">>>> </span>b.backward()</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">RuntimeError Traceback (most recent call last)</span><br><span class="line"><ipython-input<span class="number">-391</span>-a721975e1357> <span class="keyword">in</span> <module></span><br><span
class="line">----> 1 b.backward()</span><br><span class="line">...</span><br><span class="line">RuntimeError: grad can be implicitly created only <span class="keyword">for</span> scalar outputs</span><br></pre></td></tr></table></figure><p>这是因为反向传播的起点必须是标量,<code>[2., 5.]</code> 这样的向量无法直接反向传播。这时候就必须给 <code>backward()</code> 中的 <code>gradient</code> 变量指定一个与输出维度相同的张量作为权重:PyTorch 会先用它对输出加权求和得到一个标量,再进行反向传播。这里以 <code>torch.tensor([1., 1.])</code> 为例:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>b.backward(gradient=torch.tensor([<span class="number">1.</span>, <span class="number">1.</span>]))</span><br><span class="line"><span class="meta">>>> </span>b.backward(gradient=torch.ones_like(a)) <span class="comment"># 另一种写法:创建一个与 a 维度相同的全 1 张量(两行任选其一执行,否则梯度会累加)</span></span><br><span class="line"><span class="meta">>>> </span>a.grad</span><br><span class="line">tensor([<span class="number">2.</span>, <span class="number">4.</span>])</span><br></pre></td></tr></table></figure><p>关于 <code>gradient</code> 的详细讨论可以参考<a href="https://zhuanlan.zhihu.com/p/29923090" target="_blank" rel="noopener">PyTorch 的 backward 为什么有一个 grad_variables 参数?</a> 和 <a href="https://zhuanlan.zhihu.com/p/29904755" target="_blank" rel="noopener">Autograd:PyTorch中的梯度计算</a> 两篇文章。</p><h1 id="5-张量的显式求导-torch-augograd-grad"><a href="#5-张量的显式求导-torch-augograd-grad" class="headerlink" title="5. 张量的显式求导 torch.autograd.grad"></a>5. 张量的显式求导 <code>torch.autograd.grad</code></h1><p>虽然我们可以通过 <code>b.backward()</code> 来计算 <code>a.grad</code> 的值,但下面这个函数可以直接求得导数。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.autograd.grad(outputs, inputs, grad_outputs=<span class="literal">None</span>, retain_graph=<span class="literal">None</span>, create_graph=<span class="literal">False</span>, only_inputs=<span class="literal">True</span>, allow_unused=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure><p>以 $y=f(x)$ 为例,<code>inputs</code> 是 $x$,<code>outputs</code> 是 $y$。如果 $y$ 是 0 维张量,<code>grad_outputs</code> 可以忽略;否则需要为一个与 $y$ 维度相同的张量作为权重。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>x=torch.tensor([[<span class="number">1.</span>,<span class="number">2.</span>,<span class="number">3.</span>],[<span class="number">4.</span>,<span class="number">5.</span>,<span class="number">6.</span>]],requires_grad=<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>y=x+<span class="number">2</span></span><br><span class="line"><span class="meta">>>> </span>z=y*y*<span class="number">3</span></span><br><span class="line"><span class="meta">>>> </span>dzdx = torch.autograd.grad(inputs=x, outputs=z, grad_outputs=torch.ones_like(x))</span><br><span class="line"><span class="meta">>>> </span>print(dzdx)</span><br><span class="line">(tensor([[<span class="number">18.</span>, <span class="number">24.</span>, <span class="number">30.</span>],</span><br><span class="line">
[<span class="number">36.</span>, <span class="number">42.</span>, <span class="number">48.</span>]]),)</span><br></pre></td></tr></table></figure><p>上面我们对 $z$ 求了 $\frac{\partial z}{\partial x}$,结果为 $\frac{\partial z}{\partial x}=\frac{\partial z}{\partial y}\cdot\frac{\partial y}{\partial x}=2\cdot 3\cdot(x+2)\cdot 1=6(x+2)$。假如我们想求二阶偏导 $\frac{\partial^2 z}{\partial x^2}$ 呢?直接再求一次会报错:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>dzdx = torch.autograd.grad(inputs=x, outputs=z, grad_outputs=torch.ones_like(x))</span><br><span class="line">---------------------------------------------------------------------------</span><br><span class="line">RuntimeError Traceback (most recent call last)</span><br><span class="line"><ipython-input<span class="number">-440</span><span class="number">-7</span>a6333e01d6f> <span class="keyword">in</span> <module></span><br><span class="line">----> 1 dzdx = torch.autograd.grad(inputs=x, outputs=z, grad_outputs=torch.ones_like(x))</span><br><span class="line">...</span><br><span class="line">RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=<span class="literal">True</span> when calling backward the first time.</span><br></pre></td></tr></table></figure><p>这是因为动态计算图的特点是使用完毕后就会被释放:第一次求导结束后,<code>z</code> 的计算图已经被释放了。如果我们想求二阶导数,需要在第一次求导时设置 <code>create_graph=True</code>。注意区分两个变量:<code>retain_graph</code> 只是保留当前计算图不释放;<code>create_graph</code> 还会为求导运算本身构建计算图,这样求出的导数才能继续被求导,求高阶导数必须使用它。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>dzdx = torch.autograd.grad(inputs=x, outputs=z, grad_outputs=torch.ones_like(x),create_graph=<span class="literal">True</span>)</span><br><span class="line"><span class="meta">>>> </span>dz2dx2 = torch.autograd.grad(inputs=x, outputs=dzdx, grad_outputs=torch.ones_like(x))</span><br><span class="line"><span class="meta">>>> </span>print(dz2dx2)</span><br><span class="line">(tensor([[<span class="number">6.</span>, <span class="number">6.</span>, <span class="number">6.</span>],</span><br><span class="line"> [<span class="number">6.</span>, <span class="number">6.</span>, <span class="number">6.</span>]]),)</span><br></pre></td></tr></table></figure><p>结果也很好理解,$\frac{\partial^2 z}{\partial x^2}=2\cdot 3=6$。</p><h1 id="6-张量的显式反向传播计算torch-autograd-backward"><a href="#6-张量的显式反向传播计算torch-autograd-backward" class="headerlink" title="6. 张量的显式反向传播计算 torch.autograd.backward"></a>6.
张量的显式反向传播计算 <code>torch.autograd.backward</code></h1><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.autograd.backward(tensors, grad_tensors=<span class="literal">None</span>, retain_graph=<span class="literal">None</span>, create_graph=<span class="literal">False</span>)</span><br></pre></td></tr></table></figure><p>以上面的 <code>a</code> 和 <code>b</code> 为例,<code>b.backward()</code> 等价于 <code>torch.autograd.backward(b)</code>。其中 <code>grad_tensors</code> 与 <code>b.backward()</code> 中的 <code>gradient</code> 变量作用相同;<code>retain_graph</code> 和 <code>create_graph</code> 与 <code>torch.autograd.grad</code> 中的同名变量相同,不再赘述。</p>
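<p>最后补充一个小例子(仅作示意)验证这一等价关系,顺便演示非标量输出时 <code>grad_tensors</code> 的用法:</p><pre><code class="py">import torch

a = torch.tensor([1., 2.], requires_grad=True)
b = a ** 2 + 1  # 非标量输出

# 等价于 b.backward(gradient=torch.ones_like(a))
torch.autograd.backward(b, grad_tensors=torch.ones_like(a))
print(a.grad)  # tensor([2., 4.]),即 2a

# 注意:梯度默认会累加,再次反向传播前需要手动清零
a.grad.zero_()</code></pre>]]></content>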
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 3:张量的运算 2</title>
<link href="2020/05/11/DL-PyTorch-%E6%8A%98%E6%A1%82-3%EF%BC%9A%E5%BC%A0%E9%87%8F%E7%9A%84%E8%BF%90%E7%AE%97-2/"/>
<url>2020/05/11/DL-PyTorch-%E6%8A%98%E6%A1%82-3%EF%BC%9A%E5%BC%A0%E9%87%8F%E7%9A%84%E8%BF%90%E7%AE%97-2/</url>
<content type="html"><![CDATA[<p>接上文 <a href="https://vincent507cpu.github.io/2020/05/05/DL-PyTorch-折桂-2:张量的运算-1/#more">[DL] PyTorch 折桂 2:张量的运算 1</a></p><h1 id="3-math-操作"><a href="#3-math-操作" class="headerlink" title="3. math 操作"></a>3. math 操作</h1><h2 id="3-1-pointwize-操作"><a href="#3-1-pointwize-操作" class="headerlink" title="3.1 pointwise 操作"></a>3.1 pointwise 操作</h2><p>pointwise 指的是元素对元素。比如</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">A = torch.Tensor([a1, a2])</span><br><span class="line">A + <span class="number">2.</span> = torch.Tensor([a1+<span class="number">2.</span>, a2+<span class="number">2.</span>])</span><br></pre></td></tr></table></figure><a id="more"></a><h3 id="3-1-1-张量的四则运算"><a href="#3-1-1-张量的四则运算" class="headerlink" title="3.1.1 张量的四则运算"></a>3.1.1 张量的四则运算</h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">torch.add(input, other, *, alpha=<span class="number">1</span>, out=<span class="literal">None</span>) <span class="comment"># 相加</span></span><br><span class="line">torch.sub(input, other, out=<span class="literal">None</span>) <span class="comment"># 相减</span></span><br><span class="line">torch.mul(input, other, out=<span class="literal">None</span>) <span class="comment"># 相乘</span></span><br><span class="line">torch.div(input, other, out=<span class="literal">None</span>) <span class="comment"># 相除</span></span><br></pre></td></tr></table></figure><p><code>torch.add()</code> 比较特殊,它遵循如下公式:<br>$$out=input+alpha×other$$<br>所以 <code>torch.add(torch.tensor(1), torch.tensor(2), alpha=3)</code> 的运算实际上是 $1+3*2=7$。</p><p>我们还有 <code>torch.addcdiv(input, tensor1, tensor2, *, value=1, out=None)</code>,对应的运算规则为<br>$$out=input+value*\frac{tensor1}{tensor2}$$<br><code>torch.addcmul(input, tensor1, tensor2, *, value=1, out=None)</code>,对应运算规则为<br>$$out=input+value*tensor1*tensor2$$</p><h3 id="3-1-2-指数、对数、幂函数的运算"><a href="#3-1-2-指数、对数、幂函数的运算" class="headerlink" title="3.1.2 指数、对数、幂函数的运算"></a>3.1.2 指数、对数、幂函数的运算</h3><p>两个<strong>指数函数</strong>:</p><ul><li><p><code>torch.exp(input, out=None)</code> 自然指数运算:<br>$$out=e^{input}$$</p></li><li><p><code>torch.pow(input, exponent, out=None)</code> 任意指数运算:<br>$$out=input^{exponent}$$</p></li></ul><p>四个<strong>对数函数</strong>:</p><ul><li><code>torch.log(input, out=None)</code> 自然对数运算:<br>$$out=log_e{input}$$</li><li><code>torch.log1p(input, out=None)</code> 自然对数运算:<br>$$out=log_e{(input+1)}$$</li><li><code>torch.log2(input, out=None)</code> 以 2 为底的对数运算:<br>$$out=log_2{input}$$</li><li><code>torch.log10(input, out=None)</code> 以 10 为底的对数运算:<br>$$out=log_{10}{input}$$</li></ul><h3 id="3-1-3-变换函数"><a href="#3-1-3-变换函数" class="headerlink" title="3.1.3 变换函数"></a>3.1.3 变换函数</h3><ul><li><code>torch.abs(input, out=None)</code>:返回张量的绝对值。</li><li><code>torch.ceil(input, out=None)</code>:对张量向上取整。</li><li><code>torch.floor(input, out=None)</code>:对张量向下取整。</li><li><code>torch.floor_divide(input, other, out=None)</code>:张量相除后向下取整。</li><li><code>torch.fmod(input, other, out=None)</code>:对张量取余。</li><li><code>torch.neg(input, out=None)</code>:取张量的相反数。</li><li><code>torch.round(input, out=None)</code>:对张量取整。</li><li><code>torch.sigmoid(input, out=None)</code>:对张量进行 sigmoid 计算。</li><li><code>torch.sqrt(input,
out=None)</code>:对张量取平方根。</li><li><code>torch.square(input, out=None)</code>:对张量平方。</li><li><code>torch.sort(input, dim=-1, descending=False, out=None)</code>:返回张量的排序结果。</li></ul><h3 id="3-1-4-三角函数"><a href="#3-1-4-三角函数" class="headerlink" title="3.1.4 三角函数"></a>3.1.4 三角函数</h3><ul><li><code>torch.sin(input, out=None)</code>:正弦</li><li><code>torch.cos(input, out=None)</code>:余弦</li><li><code>torch.tan(input, out=None)</code>:正切</li></ul><h2 id="3-2-降维函数"><a href="#3-2-降维函数" class="headerlink" title="3.2 降维函数"></a>3.2 降维函数</h2><p>所谓降维,就是对某个维度内的元素进行聚合运算,运算后该维度被压缩。如果下述函数中的 <code>dim</code> 变量没有显式赋值,则对整个张量进行计算,返回一个值;若 <code>dim</code> 被显式赋值,则对该 <code>dim</code> 内的每组数据分别进行运算。<code>keepdim</code> 若为 <code>True</code>,被运算的维度会以大小 1 保留下来,输出的维度数与输入相同,形式上没有降维。</p><ul><li><code>torch.argmax(input, dim, keepdim=False)</code>:返回张量内最大元素的索引。</li><li><code>torch.argmin(input, dim, keepdim=False, out=None)</code>:返回张量内最小元素的索引。</li></ul><p>例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">3</span>, <span class="number">2</span>, <span class="number">4</span>], [<span class="number">9</span>, <span class="number">8</span>, <span class="number">7</span>, <span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>torch.argmax(a, dim=<span class="number">1</span>)</span><br><span class="line">tensor([<span class="number">3</span>, <span class="number">0</span>])</span><br></pre></td></tr></table></figure><ul><li><code>torch.max(input, dim, keepdim=False, out=None)</code>:返回在指定维度内进行比较后的最大值。</li><li><code>torch.min(input, dim, keepdim=False, out=None)</code>:返回在指定维度内进行比较后的最小值。</li><li><code>torch.mean(input, dim, keepdim=False, out=None)</code>:返回张量内元素的平均数。</li><li><code>torch.median(input, dim=-1, keepdim=False, out=None)</code>:返回张量内元素的中位数。</li><li><code>torch.prod(input, dim, keepdim=False, dtype=None)</code>:返回张量内元素的乘积。</li><li><code>torch.std(input, dim, unbiased=True, keepdim=False, out=None)</code>:返回张量内的标准差。</li><li><code>torch.sum(input, dim, keepdim=False, dtype=None)</code>:返回张量内元素的和。</li><li><code>torch.var(input, dim, keepdim=False, unbiased=True, out=None)</code>:返回张量内元素的方差。</li></ul><p>例子:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.ones((<span class="number">4</span>, <span class="number">3</span>)) <span class="comment"># 4 x 3 的全 1 矩阵</span></span><br><span class="line"><span class="meta">>>> </span>torch.sum(a) <span class="comment"># 没有维度,对所有元素求和</span></span><br><span class="line">tensor(<span class="number">12.</span>)</span><br><span class="line"><span class="meta">>>> </span>torch.sum(a, dim=<span class="number">1</span>)</span><br><span class="line">tensor([<span class="number">3.</span>, <span class="number">3.</span>, <span class="number">3.</span>, <span class="number">3.</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.sum(a, dim=<span class="number">1</span>, keepdim=<span
class="literal">True</span>)</span><br><span class="line">tensor([[<span class="number">3.</span>],</span><br><span class="line"> [<span class="number">3.</span>],</span><br><span class="line"> [<span class="number">3.</span>],</span><br><span class="line"> [<span class="number">3.</span>]])</span><br></pre></td></tr></table></figure><h2 id="3-3-比较函数"><a href="#3-3-比较函数" class="headerlink" title="3.3 比较函数"></a>3.3 比较函数</h2><h3 id="返回索引的函数:"><a href="#返回索引的函数:" class="headerlink" title="返回索引的函数:"></a>返回索引的函数:</h3><ul><li><code>torch.argsort(input, dim=-1, descending=False)</code> 返回指定维度上排序后各元素的索引,默认在最后一维上做升序排序:<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">3</span>, <span class="number">2</span>, <span class="number">4</span>], [<span class="number">9</span>, <span class="number">8</span>, <span class="number">7</span>, <span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>torch.argsort(a)</span><br><span class="line">tensor([[<span class="number">0</span>, <span class="number">2</span>, <span class="number">1</span>, <span class="number">3</span>],</span><br><span class="line"> [<span class="number">3</span>, <span class="number">2</span>, <span class="number">1</span>, <span class="number">0</span>]])</span><br></pre></td></tr></table></figure></li></ul><h3 id="既返回值,又返回索引的函数:"><a href="#既返回值,又返回索引的函数:" class="headerlink" title="既返回值,又返回索引的函数:"></a>既返回值,又返回索引的函数:</h3><ul><li><code>torch.sort(input, dim=-1, descending=False, out=None)</code>:对张量进行排序。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>x = torch.randn(<span class="number">3</span>, <span class="number">4</span>)</span><br><span class="line"><span class="meta">>>> </span>sorted, indices = torch.sort(x)</span><br><span class="line"><span class="meta">>>> </span>sorted</span><br><span class="line">tensor([[<span class="number">-0.2162</span>, <span class="number">0.0608</span>, <span class="number">0.6719</span>, <span class="number">2.3332</span>],</span><br><span class="line"> [<span class="number">-0.5793</span>, <span class="number">0.0061</span>, <span class="number">0.6058</span>, <span class="number">0.9497</span>],</span><br><span class="line"> [<span class="number">-0.5071</span>, <span class="number">0.3343</span>, <span class="number">0.9553</span>, <span class="number">1.0960</span>]])</span><br><span class="line"><span class="meta">>>> </span>indices</span><br><span class="line">tensor([[ <span class="number">1</span>, <span class="number">0</span>, <span class="number">2</span>, <span class="number">3</span>],</span><br><span class="line"> [ <span class="number">3</span>, <span class="number">1</span>, <span class="number">0</span>, <span class="number">2</span>],</span><br><span class="line"> [ <span class="number">0</span>, <span class="number">3</span>, <span class="number">1</span>, <span
class="number">2</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)</code>:返回最大/最小的 <code>k</code> 个值和它们的索引。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>x = torch.arange(<span class="number">1.</span>, <span class="number">6.</span>)</span><br><span class="line"><span class="meta">>>> </span>x</span><br><span class="line">tensor([ <span class="number">1.</span>, <span class="number">2.</span>, <span class="number">3.</span>, <span class="number">4.</span>, <span class="number">5.</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.topk(x, <span class="number">3</span>)</span><br><span class="line">torch.return_types.topk(values=tensor([<span class="number">5.</span>, <span class="number">4.</span>, <span class="number">3.</span>]), indices=tensor([<span class="number">4</span>, <span class="number">3</span>, <span class="number">2</span>]))</span><br></pre></td></tr></table></figure></li><li><code>torch.cummax(input, dim, out=None)</code>:返回截至当前位置(含当前位置)的最大值及其索引。</li><li><code>torch.cummin(input, dim, out=None)</code>:返回截至当前位置(含当前位置)的最小值及其索引。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.randn(<span class="number">10</span>)</span><br><span class="line"><span class="meta">>>> </span>a</span><br><span class="line">tensor([<span class="number">-0.3449</span>, <span class="number">-1.5447</span>, <span class="number">0.0685</span>, <span class="number">-1.5104</span>, <span class="number">-1.1706</span>, <span class="number">0.2259</span>, <span class="number">1.4696</span>, <span class="number">-1.3284</span>,</span><br><span class="line"> <span class="number">1.9946</span>, <span class="number">-0.8209</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.cummax(a, dim=<span class="number">0</span>)</span><br><span class="line">torch.return_types.cummax(</span><br><span class="line"> values=tensor([<span class="number">-0.3449</span>, <span class="number">-0.3449</span>, <span class="number">0.0685</span>, <span class="number">0.0685</span>, <span class="number">0.0685</span>, <span class="number">0.2259</span>, <span class="number">1.4696</span>, <span class="number">1.4696</span>,</span><br><span class="line"> <span class="number">1.9946</span>, <span class="number">1.9946</span>]),</span><br><span class="line"> indices=tensor([<span class="number">0</span>, <span class="number">0</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">5</span>, <span
class="number">6</span>, <span class="number">6</span>, <span class="number">8</span>, <span class="number">8</span>]))</span><br><span class="line"> </span><br><span class="line"><span class="meta">>>> </span>a = torch.randn(<span class="number">10</span>)</span><br><span class="line"><span class="meta">>>> </span>a</span><br><span class="line">tensor([<span class="number">-0.2284</span>, <span class="number">-0.6628</span>, <span class="number">0.0975</span>, <span class="number">0.2680</span>, <span class="number">-1.3298</span>, <span class="number">-0.4220</span>, <span class="number">-0.3885</span>, <span class="number">1.1762</span>,</span><br><span class="line"> <span class="number">0.9165</span>, <span class="number">1.6684</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.cummin(a, dim=<span class="number">0</span>)</span><br><span class="line">torch.return_types.cummin(</span><br><span class="line"> values=tensor([<span class="number">-0.2284</span>, <span class="number">-0.6628</span>, <span class="number">-0.6628</span>, <span class="number">-0.6628</span>, <span class="number">-1.3298</span>, <span class="number">-1.3298</span>, <span class="number">-1.3298</span>, <span class="number">-1.3298</span>,</span><br><span class="line"> <span class="number">-1.3298</span>, <span class="number">-1.3298</span>]),</span><br><span class="line"> indices=tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">4</span>, <span class="number">4</span>, <span class="number">4</span>, <span class="number">4</span>, <span class="number">4</span>, <span class="number">4</span>]))</span><br></pre></td></tr></table></figure></li></ul><h3 id="比较两个张量的元素,返回包含每个元素间比较的最大-小值:"><a href="#比较两个张量的元素,返回包含每个元素间比较的最大-小值:" class="headerlink" title="比较两个张量的元素,返回包含每个元素间比较的最大/小值:"></a>比较两个张量的元素,返回包含每个元素间比较的最大/小值:</h3><ul><li><code>torch.max(input, other, out=None)</code></li><li><code>torch.min(input, other, out=None)</code></li></ul><p>这两个函数与上面的降维函数中的同名函数的区别在于上面的两个函数的输入是一个张量,这里是两个。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">3</span>, <span class="number">96</span>, <span class="number">97</span>], [<span class="number">98</span>, <span class="number">99</span>, <span class="number">7</span>, <span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>b = torch.tensor([[<span class="number">100</span>, <span class="number">101</span>, <span class="number">-1</span>, <span class="number">-2</span>], [<span class="number">-3</span>, <span class="number">-4</span>, <span class="number">102</span>, <span class="number">103</span>]])</span><br><span class="line"></span><br><span class="line"><span class="meta">>>> </span>torch.max(a, b)</span><br><span class="line">tensor([[<span class="number">100</span>, <span class="number">101</span>, <span class="number">96</span>, <span class="number">97</span>],</span><br><span class="line"> [ <span class="number">98</span>, <span class="number">99</span>, <span 
class="number">102</span>, <span class="number">103</span>]])</span><br><span class="line"><span class="meta">>>> </span>torch.min(a, b)</span><br><span class="line">tensor([[ <span class="number">1</span>, <span class="number">3</span>, <span class="number">-1</span>, <span class="number">-2</span>],</span><br><span class="line"> [<span class="number">-3</span>, <span class="number">-4</span>, <span class="number">7</span>, <span class="number">6</span>]])</span><br></pre></td></tr></table></figure><h3 id="两个张量比较是否相同,返回一个布尔值:torch-equal-input-other"><a href="#两个张量比较是否相同,返回一个布尔值:torch-equal-input-other" class="headerlink" title="两个张量比较是否相同,返回一个布尔值:torch.equal(input, other)"></a>两个张量比较是否相同,返回一个布尔值:<code>torch.equal(input, other)</code></h3><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.equal(torch.tensor([<span class="number">1</span>, <span class="number">2</span>]), torch.tensor([<span class="number">1</span>, <span class="number">2</span>]))</span><br><span class="line"><span class="literal">True</span></span><br></pre></td></tr></table></figure><p>两个张量的元素之间互相比较,每个比较返回一个布尔值,最终返回一个与被比较元素形状相同的张量:</p><ul><li><code>torch.eq(input, other, out=None)</code>:如果 <code>input</code> 中的元素等于 <code>output</code> 中的对应元素,返回 <code>True</code>。</li><li><code>torch.ge(input, other, out=None)</code>:如果 <code>input</code> 中的元素大于等于 <code>output</code> 中的对应元素,返回 <code>True</code>。</li><li><code>torch.gt(input, other, out=None)</code>:如果 <code>input</code> 中的元素大于 <code>output</code> 中的对应元素,返回 <code>True</code>。</li><li><code>torch.le(input, other, out=None)</code>:如果 <code>input</code> 中的元素小于等于 <code>output</code> 中的对应元素,返回 <code>True</code>。</li><li><code>torch.lt(input, other, out=None)</code>:如果 <code>input</code> 中的元素小于 <code>output</code> 中的对应元素,返回 <code>True</code>。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">3</span>, <span class="number">96</span>, <span class="number">97</span>], [<span class="number">98</span>, <span class="number">99</span>, <span class="number">7</span>, <span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>c = torch.tensor([[<span class="number">1</span>, <span class="number">3</span>, <span class="number">5</span>, <span class="number">7</span>], [<span class="number">98</span>, <span class="number">99</span>, <span class="number">100</span>, <span class="number">101</span>]])</span><br><span class="line"><span class="meta">>>> </span>torch.eq(a, c)</span><br><span class="line">tensor([[ <span class="literal">True</span>, <span class="literal">True</span>, <span class="literal">False</span>, <span class="literal">False</span>],</span><br><span class="line"> [ <span class="literal">True</span>, <span class="literal">True</span>, <span class="literal">False</span>, <span class="literal">False</span>]])</span><br></pre></td></tr></table></figure></li></ul><h1 id="4-随机函数"><a href="#4-随机函数" class="headerlink" title="4. 随机函数"></a>4. 
随机函数</h1><p>大部分随机函数都有一个 <code>generator</code> 参数,用于指定生成随机数所用的随机数生成器;全局的随机种子则通过 <code>torch.manual_seed()</code> 设置。</p><ul><li><code>torch.manual_seed(seed)</code>:设置随机种子。</li><li><code>torch.bernoulli(input, *, generator=None, out=None)</code>:以 <code>input</code> 中的每个元素为概率,生成服从伯努利分布(0-1 分布)的张量。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.empty(<span class="number">2</span>, <span class="number">2</span>).uniform_(<span class="number">0</span>, <span class="number">1</span>)</span><br><span class="line"><span class="meta">>>> </span>a</span><br><span class="line">tensor([[<span class="number">0.0117</span>, <span class="number">0.2281</span>],</span><br><span class="line"> [<span class="number">0.8750</span>, <span class="number">0.9974</span>]])</span><br><span class="line"><span class="meta">>>> </span>torch.bernoulli(a)</span><br><span class="line">tensor([[<span class="number">0.</span>, <span class="number">0.</span>],</span><br><span class="line"> [<span class="number">1.</span>, <span class="number">1.</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None)</code>:生成符合多项式分布的张量。<code>input</code> 为多项式分布的权重,当 <code>replacement</code> 为 <code>False</code>(不放回采样)时,<code>num_samples</code> 不能大于 <code>input</code> 中非零权重的个数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>weights = torch.tensor([<span class="number">1.</span>, <span class="number">2.</span>, <span class="number">3.</span>, <span class="number">4.</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.multinomial(weights, <span class="number">10</span>, replacement=<span class="literal">True</span>)</span><br><span class="line">tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">2</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">0</span>, <span class="number">2</span>])</span><br></pre></td></tr></table></figure></li><li><code>torch.normal(mean, std, size, *, out=None)</code>:生成服从均值为 <code>mean</code>,标准差为 <code>std</code> 的正态分布张量。<code>mean</code> 和 <code>std</code> 可以省略一个,若想同时省略请使用 <code>torch.randn</code> 函数。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.normal(<span class="number">2</span>, <span class="number">1</span>, [<span class="number">2</span>, <span class="number">2</span>])</span><br><span class="line">tensor([[<span class="number">1.7697</span>, <span class="number">2.2627</span>],</span><br><span class="line"> [<span class="number">2.0743</span>, <span class="number">2.1683</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.poisson(input, *, generator=None)</code>:生成一个形状与 <code>input</code> 相同,服从<a href="https://zh.wikipedia.org/wiki/泊松分佈" target="_blank" rel="noopener" 
title="泊松分布">泊松分布</a>的张量。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.poisson(torch.tensor([<span class="number">2.</span>, <span class="number">2.</span>]))</span><br><span class="line">tensor([<span class="number">1.</span>, <span class="number">4.</span>])</span><br></pre></td></tr></table></figure></li><li><code>torch.rand(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code>:生成一个范围为 $[0,1)$ 的均匀分布的张量。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.rand((<span class="number">2</span>, <span class="number">2</span>))</span><br><span class="line">tensor([[<span class="number">0.2255</span>, <span class="number">0.5614</span>],</span><br><span class="line"> [<span class="number">0.7037</span>, <span class="number">0.2410</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.randn(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code>:生成一个均值为 0,方差为 1 的标准正态分布的张量。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.randn((<span class="number">2</span>, <span class="number">2</span>))</span><br><span class="line">tensor([[ <span class="number">1.2622</span>, <span class="number">-1.3420</span>],</span><br><span class="line"> [<span class="number">-0.2331</span>, <span class="number">0.6151</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.randint(low=0, high, size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code>:生成一个范围为 $[low, high)$ 内取整数的均匀分布的张量。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.randint(<span class="number">10</span>, (<span class="number">2</span>, <span class="number">2</span>))</span><br><span class="line">tensor([[<span class="number">6</span>, <span class="number">2</span>],</span><br><span class="line"> [<span class="number">2</span>, <span class="number">3</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.randperm(n, out=None, dtype=torch.int64, layout=torch.strided, device=None, requires_grad=False)</code>:返回一个经过随机打乱顺序的张量。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.randperm(<span class="number">10</span>)</span><br><span class="line">tensor([<span class="number">8</span>, <span class="number">9</span>, <span class="number">5</span>, <span class="number">3</span>, <span class="number">2</span>, <span class="number">1</span>, <span class="number">0</span>, <span class="number">7</span>, <span class="number">4</span>, <span class="number">6</span>])</span><br></pre></td></tr></table></figure></li></ul>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 2:张量的运算 1</title>
<link href="2020/05/05/DL-PyTorch-%E6%8A%98%E6%A1%82-2%EF%BC%9A%E5%BC%A0%E9%87%8F%E7%9A%84%E8%BF%90%E7%AE%97-1/"/>
<url>2020/05/05/DL-PyTorch-%E6%8A%98%E6%A1%82-2%EF%BC%9A%E5%BC%A0%E9%87%8F%E7%9A%84%E8%BF%90%E7%AE%97-1/</url>
<content type="html"><![CDATA[<h1 id="1-tensor-API-总览"><a href="#1-tensor-API-总览" class="headerlink" title="1. tensor API 总览"></a>1. <code>tensor</code> API 总览</h1><p>根据官方文档,对 <code>tensor</code> 可以进行如下操作:</p><ol><li>creation 操作</li><li>indexing,slicing,joining 及 mutating 操作</li><li>math 操作</li></ol><ul><li>elementwise 操作</li><li>reduction 操作</li><li>comparison 操作</li><li>spectral 操作</li><li>linear algebra 操作</li></ul><ol start="4"><li>random sampling 操作</li><li>serialization 操作</li><li>parallelism 操作</li></ol><p>本文以及接下来的几篇文章会关注前四个操作,最后两个操作会在合适的时候提到。</p><a id="more"></a><h1 id="2-creation-操作"><a href="#2-creation-操作" class="headerlink" title="2. creation 操作"></a>2. creation 操作</h1><h1 id="2-1-转化一个已有数组"><a href="#2-1-转化一个已有数组" class="headerlink" title="2.1 转化一个已有数组"></a>2.1 转化一个已有数组</h1><p>把一个已有数组转化成 <code>Tensor</code>,通常有四种方法:</p><ul><li><code>torch.Tensor()</code></li><li><code>torch.tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False)</code></li><li><code>torch.from_numpy(ndarray)</code></li><li><code>torch.as_tensor(data, dtype=None, device=None)</code></li></ul><p>因为 <code>torch.from_numpy()</code> 只能转化一个 <code>numpy array</code>,所以先使用 <code>numpy</code> 创建一个数组。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br><span class="line">13</span><br><span class="line">14</span><br><span class="line">15</span><br><span class="line">16</span><br><span class="line">17</span><br><span class="line">18</span><br><span class="line">19</span><br><span class="line">20</span><br><span class="line">21</span><br><span class="line">22</span><br><span class="line">23</span><br><span class="line">24</span><br><span class="line">25</span><br></pre></td><td class="code"><pre><span class="line">array = np.arange(<span class="number">5</span>)</span><br><span class="line"><span class="comment"># 方法 1</span></span><br><span class="line"><span class="meta">>>> </span>tensor1 = torch.Tensor(array)</span><br><span class="line"><span class="meta">>>> </span>print(tensor1)</span><br><span class="line">tensor([<span class="number">0.</span>, <span class="number">1.</span>, <span class="number">2.</span>, <span class="number">3.</span>, <span class="number">4.</span>])</span><br><span class="line"><span class="meta">>>> </span>print(tensor1.dtype)</span><br><span class="line">torch.float32</span><br><span class="line"><span class="comment"># 方法 2</span></span><br><span class="line"><span class="meta">>>> </span>tensor2 = torch.tensor(array)</span><br><span class="line"><span class="meta">>>> </span>print(tensor2)</span><br><span class="line">tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>])</span><br><span class="line"><span class="meta">>>> </span>print(tensor2.dtype)</span><br><span class="line">torch.int64</span><br><span class="line"><span class="comment"># 方法 3</span></span><br><span class="line"><span class="meta">>>> </span>tensor3 = torch.from_numpy(array)</span><br><span class="line"><span class="meta">>>> </span>print(tensor3)</span><br><span class="line">tensor([<span 
class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>])</span><br><span class="line"><span class="meta">>>> </span>print(tensor3.dtype)</span><br><span class="line">torch.int64</span><br><span class="line"><span class="comment"># 方法 4</span></span><br><span class="line"><span class="meta">>>> </span>tensor4 = torch.as_tensor(array)</span><br><span class="line"><span class="meta">>>> </span>print(tensor4)</span><br><span class="line">tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>])</span><br><span class="line"><span class="meta">>>> </span>print(tensor14.dtype)</span><br><span class="line">torch.int64</span><br></pre></td></tr></table></figure><p>可以看出,<code>torch.Tensor()</code> 没有保留数值类型,其它三个都保留了。这是因为 <code>torch.Tensor()</code> 实际上是一个类,传入的数据需要“初始化”;其它三个都是函数,而通过 <code>torch.Tensor()</code> 生成的张量的数据类型是由一个环境变量决定的,这个环境变量可以通过 <code>torch.set_default_tensor_type(t)</code> 这个函数来设定。那么新的张量与原来的数组是什么关系呢?</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>tensor1[<span class="number">0</span>] = <span class="number">100</span></span><br><span class="line"><span class="meta">>>> </span>print(array)</span><br><span class="line">[<span class="number">0</span> <span class="number">1</span> <span class="number">2</span> <span class="number">3</span> <span class="number">4</span>]</span><br><span class="line"><span class="meta">>>> </span>tensor2[<span class="number">1</span>] = <span class="number">100</span></span><br><span class="line"><span class="meta">>>> </span>print(array)</span><br><span class="line">[<span class="number">0</span> <span class="number">1</span> <span class="number">2</span> <span class="number">3</span> <span class="number">4</span>]</span><br><span class="line"><span class="meta">>>> </span>tensor3[<span class="number">2</span>] = <span class="number">100</span></span><br><span class="line"><span class="meta">>>> </span>print(array</span><br><span class="line">[ <span class="number">0</span> <span class="number">1</span> <span class="number">100</span> <span class="number">3</span> <span class="number">4</span>]</span><br><span class="line"><span class="meta">>>> </span>tensor4[<span class="number">3</span>] = <span class="number">100</span></span><br><span class="line"><span class="meta">>>> </span>print(array)</span><br><span class="line">[ <span class="number">0</span> <span class="number">1</span> <span class="number">100</span> <span class="number">100</span> <span class="number">4</span>]</span><br></pre></td></tr></table></figure><p><code>torch.Tensor()</code> 和 <code>torch.tensor()</code> 复制了原数组的数据,<code>torch.from_numpy()</code> 和 <code>torch.as_tensor()</code> 直接与原数组共享数据。</p><p>综上所述,需要创建新张量时,推荐使用 <code>torch.tensor()</code>;需要避免复制时,推荐使用 <code>torch.as_tensor()</code>。理由如下:</p><ol><li><code>torch.tensor()</code> 和 <code>torch.as_tensor()</code> 的 API 更丰富,可控制的属性更多;</li><li><code>torch.Tensor()</code> 
<h2 id="2-2-其它创建张量的方法"><a href="#2-2-其它创建张量的方法" class="headerlink" title="2.2 其它创建张量的方法"></a>2.2 其它创建张量的方法</h2><h3 id="2-2-1-创建全部值为定值的函数"><a href="#2-2-1-创建全部值为定值的函数" class="headerlink" title="2.2.1 创建全部值为定值的函数"></a>2.2.1 创建全部值为定值的函数</h3><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line">torch.zeros(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) # 全为 0</span><br><span class="line"></span><br><span class="line">torch.ones(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) # 全为 1</span><br><span class="line"></span><br><span class="line">torch.empty(*size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) # 值未初始化</span><br></pre></td></tr></table></figure><p>这三个函数的参数完全相同,放在一起说了:</p><ul><li><code>*size</code>:新张量的形状;</li><li><code>out</code>:输出的已有张量名称;</li><li><code>dtype</code>:数据类型;</li><li><code>layout</code>:内存里的存储方式;</li><li><code>device</code>:存储设备;</li><li><code>requires_grad</code>:是否追踪导数。</li></ul><p>最后一个函数 <code>torch.empty</code> 返回的张量并没有被初始化:它直接使用分配到的那块内存里原有的内容,值是任意的,完全可能非常大:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.empty(<span class="number">1</span>)</span><br><span class="line">tensor([<span class="number">2.0890e+20</span>])</span><br></pre></td></tr></table></figure><p>还有一个类似的函数,可以指定填充的数值:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">torch.full(size, fill_value, out=<span class="literal">None</span>, dtype=<span class="literal">None</span>, layout=torch.strided, device=<span class="literal">None</span>, requires_grad=<span class="literal">False</span>) <span class="comment"># 全为一个定值</span></span><br></pre></td></tr></table></figure><p>还可以根据已有的张量,按照该张量的形状生成相同形状的新张量:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">torch.zeros_like(input, dtype=<span class="literal">None</span>, layout=<span class="literal">None</span>, device=<span class="literal">None</span>, requires_grad=<span class="literal">False</span>, memory_format=torch.preserve_format)</span><br><span class="line"></span><br><span class="line">torch.ones_like(input, dtype=<span class="literal">None</span>, layout=<span class="literal">None</span>, device=<span class="literal">None</span>, requires_grad=<span class="literal">False</span>, memory_format=torch.preserve_format)</span><br><span class="line"></span><br><span class="line">torch.empty_like(input, dtype=<span class="literal">None</span>, layout=<span class="literal">None</span>, device=<span class="literal">None</span>, requires_grad=<span class="literal">False</span>, memory_format=torch.preserve_format)</span><br><span class="line"></span><br><span class="line">torch.full_like(input, fill_value, dtype=<span class="literal">None</span>, 
layout=<span class="literal">None</span>, device=<span class="literal">None</span>, requires_grad=<span class="literal">False</span>, memory_format=torch.preserve_format)</span><br></pre></td></tr></table></figure><p>举个例子,假设 <code>a</code> 是一个 <code>(2, 2)</code> 的矩阵张量:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.ones(<span class="number">2</span>, <span class="number">2</span>)</span><br><span class="line"><span class="meta">>>> </span>b = torch.zeros_like(a)</span><br><span class="line"><span class="meta">>>> </span>print(a, b)</span><br><span class="line">tensor([[<span class="number">1.</span>, <span class="number">1.</span>],</span><br><span class="line"> [<span class="number">1.</span>, <span class="number">1.</span>]]) </span><br><span class="line">tensor([[<span class="number">0.</span>, <span class="number">0.</span>],</span><br><span class="line"> [<span class="number">0.</span>, <span class="number">0.</span>]])</span><br></pre></td></tr></table></figure><h3 id="2-2-2-创建一个元素间隔为常量的函数"><a href="#2-2-2-创建一个元素间隔为常量的函数" class="headerlink" title="2.2.2 创建一个元素间隔为常量的函数"></a>2.2.2 创建一个元素间隔为常量的函数</h3><ul><li><code>torch.arange(start=0, end, step=1, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code></li></ul><p>创建一个 1 维张量,范围为 <code>[start, end)</code>,<strong>步进</strong>为 <code>step</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.arange(<span class="number">4</span>)</span><br><span class="line">tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>])</span><br></pre></td></tr></table></figure><ul><li><code>torch.linspace(start, end, steps=100, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code></li></ul><p>创建一个 1 维张量,范围为 <code>[start, end]</code>,<strong>步数</strong>为 <code>step</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.linspace(<span class="number">0</span>, <span class="number">5</span>, <span class="number">6</span>)</span><br><span class="line">tensor([<span class="number">0.</span>, <span class="number">1.</span>, <span class="number">2.</span>, <span class="number">3.</span>, <span class="number">4.</span>, <span class="number">5.</span>])</span><br></pre></td></tr></table></figure><ul><li><code>torch.logspace(start, end, steps=100, base=10.0, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code></li></ul><p>创建一个 1 维张量,范围为 $[base^{start}, base^{end}]$,<strong>步数</strong>为 <code>step</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.logspace(<span class="number">-10</span>, <span class="number">10</span>, <span class="number">5</span>, <span 
class="number">10</span>)</span><br><span class="line">tensor([<span class="number">1.0000e-10</span>, <span class="number">1.0000e-05</span>, <span class="number">1.0000e+00</span>, <span class="number">1.0000e+05</span>, <span class="number">1.0000e+10</span>])</span><br></pre></td></tr></table></figure><h2 id="2-2-3-创建一个对角张量"><a href="#2-2-3-创建一个对角张量" class="headerlink" title="2.2.3 创建一个对角张量"></a>2.2.3 创建一个对角张量</h2><ul><li><code>torch.eye(n, m=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)</code></li></ul><p>对角张量指张量的对角线上的元素为 1,其余值为 0 的张量。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.eye(<span class="number">3</span>)</span><br><span class="line">tensor([[ <span class="number">1.</span>, <span class="number">0.</span>, <span class="number">0.</span>],</span><br><span class="line"> [ <span class="number">0.</span>, <span class="number">1.</span>, <span class="number">0.</span>],</span><br><span class="line"> [ <span class="number">0.</span>, <span class="number">0.</span>, <span class="number">1.</span>]])</span><br></pre></td></tr></table></figure><h1 id="3-indexing,slicing,joining-及-mutating-操作"><a href="#3-indexing,slicing,joining-及-mutating-操作" class="headerlink" title="3. indexing,slicing,joining 及 mutating 操作"></a>3. indexing,slicing,joining 及 mutating 操作</h1><h2 id="3-1-indexing-操作"><a href="#3-1-indexing-操作" class="headerlink" title="3.1 indexing 操作"></a>3.1 indexing 操作</h2><ul><li>PyTorch 支持 Python 式的索引操作。<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>], [<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>]])</span><br><span class="line"><span class="meta">>>> </span>a[<span class="number">0</span>][<span class="number">1</span>]</span><br><span class="line">tensor(<span class="number">2</span>)</span><br></pre></td></tr></table></figure></li><li><code>torch.index_select(input, dim, index, out=None)</code></li></ul><p>根据指定索引在指定轴上索引。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>], [<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>], [<span class="number">7</span>, <span class="number">8</span>, <span class="number">9</span>]])</span><br><span class="line"><span class="meta">>>> </span>indices = torch.tensor([<span class="number">0</span>, <span class="number">2</span>])</span><br><span class="line"><span class="meta">>>> </span>torch.index_select(a, <span class="number">0</span>, indices) <span class="comment"># 选取第 0 轴上第 0,2 个元素</span></span><br><span class="line">tensor([[<span class="number">1</span>, <span class="number">2</span>, <span 
class="number">3</span>],</span><br><span class="line"> [<span class="number">7</span>, <span class="number">8</span>, <span class="number">9</span>]])</span><br></pre></td></tr></table></figure><ul><li><code>torch.masked_select(input, mask, out=None)</code></li></ul><p>在 1 位张量上以布尔值进行索引。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>])</span><br><span class="line"><span class="meta">>>> </span>mask = a.lt(<span class="number">2</span> <span class="comment"># 以“小于 2”为条件创建布尔值</span></span><br><span class="line"><span class="meta">>>> </span>torch.masked_select(a, mask)</span><br><span class="line">tensor([<span class="number">0</span>, <span class="number">1</span>])</span><br></pre></td></tr></table></figure><h2 id="3-2-slicing-操作"><a href="#3-2-slicing-操作" class="headerlink" title="3.2 slicing 操作"></a>3.2 slicing 操作</h2><ul><li>Python 原生 slicing 操作<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([[<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>], [<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>], [<span class="number">7</span>, <span class="number">8</span>, <span class="number">9</span>]])</span><br><span class="line"><span class="meta">>>> </span>a[<span class="number">0</span>:<span class="number">2</span>, :]</span><br><span class="line">tensor([[<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>],</span><br><span class="line"> [<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>]])</span><br></pre></td></tr></table></figure></li><li><code>torch.split(tensor, split_size_or_sections, dim=0)</code></li></ul><p>按照给定维度进行切片。如果 <code>split_size_or_sections</code> 是一个整数,则以该数字为单位进行切片。如果张量在该维的长度不能被整除,最后一片的尺寸会小。</p><p>如果 <code>split_size_or_sections</code> 是一个列表,张量会按每个元素值切片。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.arange(<span class="number">10</span>)</span><br><span class="line"><span class="meta">>>> </span>torch.split(a, <span class="number">2</span>, <span class="number">0</span>)</span><br><span class="line">(tensor([<span class="number">0</span>, <span class="number">1</span>]),</span><br><span class="line"> tensor([<span class="number">2</span>, <span class="number">3</span>]),</span><br><span class="line"> tensor([<span class="number">4</span>, <span class="number">5</span>]),</span><br><span class="line"> tensor([<span class="number">6</span>, <span class="number">7</span>]),</span><br><span class="line"> tensor([<span 
class="number">8</span>, <span class="number">9</span>]))</span><br><span class="line"><span class="meta">>>> </span>torch.split(a, [<span class="number">3</span>, <span class="number">7</span>], <span class="number">0</span>)</span><br><span class="line">(tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>]), tensor([<span class="number">3</span>, <span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>, <span class="number">7</span>, <span class="number">8</span>, <span class="number">9</span>]))</span><br></pre></td></tr></table></figure><ul><li><code>torch.chunk(input, chunks, dim=0)</code></li></ul><p>在给定维度上将张量切成 <code>chunk</code> 份。若张量长度不能整除,则最后一份的长度会小。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>torch.chunk(a, <span class="number">3</span>, <span class="number">0</span>)</span><br><span class="line">(tensor([<span class="number">0</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>]), tensor([<span class="number">4</span>, <span class="number">5</span>, <span class="number">6</span>, <span class="number">7</span>]), tensor([<span class="number">8</span>, <span class="number">9</span>]))</span><br></pre></td></tr></table></figure><h2 id="3-3-joining-操作"><a href="#3-3-joining-操作" class="headerlink" title="3.3 joining 操作"></a>3.3 joining 操作</h2><ul><li><code>torch.cat(tensors, dim=0, out=None)</code></li></ul><p>在不增加维度的情况下聚合若干个张量。</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">>>> x = torch.arange(6).reshape(2, 3)</span><br><span class="line">>>> torch.cat([x, x], dim=0) # 在第 0 轴聚合</span><br><span class="line">tensor([[0, 1, 2],</span><br><span class="line"> [3, 4, 5],</span><br><span class="line"> [0, 1, 2],</span><br><span class="line"> [3, 4, 5]])</span><br><span class="line">>>> torch.cat([x, x], dim=1) # 在第 1 轴聚合</span><br><span class="line">tensor([[0, 1, 2, 0, 1, 2],</span><br><span class="line"> [3, 4, 5, 3, 4, 5]])</span><br></pre></td></tr></table></figure><ul><li><code>torch.stack(tensors, dim=0, out=None)</code></li></ul><p>将两个张量叠加到一起,这会产生一个新的轴。</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">>>> torch.stack([x, x], dim=0)</span><br><span class="line">tensor([[[0, 1, 2],</span><br><span class="line"> [3, 4, 5]],</span><br><span class="line"></span><br><span class="line"> [[0, 1, 2],</span><br><span class="line"> [3, 4, 5]]])</span><br><span class="line">>>> torch.stack([x, x], dim=1)</span><br><span class="line">tensor([[[0, 1, 2],</span><br><span class="line"> [0, 1, 2]],</span><br><span 
class="line"></span><br><span class="line"> [[3, 4, 5],</span><br><span class="line"> [3, 4, 5]]])</span><br></pre></td></tr></table></figure>]]></content>
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[DL] PyTorch 折桂 1:张量的性质</title>
<link href="2020/05/05/DL-PyTorch-%E6%8A%98%E6%A1%82-1%EF%BC%9A%E5%BC%A0%E9%87%8F%E7%9A%84%E6%80%A7%E8%B4%A8/"/>
<url>2020/05/05/DL-PyTorch-%E6%8A%98%E6%A1%82-1%EF%BC%9A%E5%BC%A0%E9%87%8F%E7%9A%84%E6%80%A7%E8%B4%A8/</url>
<content type="html"><![CDATA[<h1 id="1-张量"><a href="#1-张量" class="headerlink" title="1. 张量"></a>1. 张量</h1><p>张量就是在深度学习里,可以使用 GPU 运算的多维数组。</p><ul><li>0 维张量是一个标量(scalar);</li><li>1 维张量是一个矢量(vector);</li><li>2 维张量是一个矩阵(matrix);</li><li>3 维以上的张量没有通俗的表示。</li></ul><h1 id="2-张量的数据类型"><a href="#2-张量的数据类型" class="headerlink" title="2. 张量的数据类型"></a>2. 张量的数据类型</h1><p>张量一共有三种,整数型、浮点型和布尔型。整数型和浮点型张量的精度分别有 8 位、16 位、32 位、64 位。</p><table><thead><tr><th>类型</th><th>精度</th><th>表示</th></tr></thead><tbody><tr><td>整形</td><td>8 位</td><td>torch.int8</td></tr><tr><td></td><td>16 位</td><td>torch.int16 或 torch.short</td></tr><tr><td></td><td>32 位</td><td>torch.int 或 torch.int32</td></tr><tr><td></td><td>64 位</td><td>torch.int64 或torch.long</td></tr><tr><td>浮点型</td><td>16 位</td><td>torch.float16 或 torch.half</td></tr><tr><td></td><td>32 位</td><td>torch.float 或 torch.float32</td></tr><tr><td></td><td>64 位</td><td>torch.float64 或 torch.double</td></tr><tr><td>布尔型</td><td></td><td>torch.bool</td></tr></tbody></table><p>获得一个张量的数据类型可以通过 <code>Tensor.dtype</code> 实现;如果给这个表达式赋值则将这个张量的数据类型改为目标类型。</p><a id="more"></a><h1 id="3-PyTorch-的不同形态"><a href="#3-PyTorch-的不同形态" class="headerlink" title="3. PyTorch 的不同形态"></a>3. PyTorch 的不同形态</h1><p>PyTorch 是很灵活,可以通过不同方式达到同样的目的。</p><h2 id="3-1-函数功能:torch-function-与-Tensor-function"><a href="#3-1-函数功能:torch-function-与-Tensor-function" class="headerlink" title="3.1 函数功能:torch.function() 与 Tensor.function()"></a>3.1 函数功能:<code>torch.function()</code> 与 <code>Tensor.function()</code></h2><p>首先,让我们有一个约定:如果我们说 <code>Tensor.xxx()</code>,这个 <code>Tensor</code> 指的是一个具体的张量。</p><p>在 PyTorch 中,张量的很多运算既可以通过它自身的方法,也可以作为 PyTorch 中的一个低级函数来实现,比如两个张量 <code>a</code> 和 <code>b</code>相加,既可以写成 <code>torch.add(a, b)</code>,也可以写成 <code>a.add(b)</code>。</p><h2 id="3-2-赋值语句:"><a href="#3-2-赋值语句:" class="headerlink" title="3.2 赋值语句:"></a>3.2 赋值语句:</h2><p>很多张量的属性既可以在创建时声明,也可以在之后任何时间声明。比如我想把一个值为 <code>1</code> 的 32 位整数张量赋给变量 <code>a</code>,我可以在生成时一步到位,</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">a = torch.tensor(<span class="number">1</span>, dtype=torch.int32)</span><br></pre></td></tr></table></figure><p>也可以先生成 <code>a</code> 的张量,然后再改变它的数据类型。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">a = torch.tenor(<span class="number">1</span>)</span><br><span class="line">a.dtype = torch.int32</span><br></pre></td></tr></table></figure><h1 id="4-张量的存储"><a href="#4-张量的存储" class="headerlink" title="4. 张量的存储"></a>4. 
张量的存储</h1><p>张量存储在连续的内存中,被 <code>torch.Storage</code> 控制。一个 *Storage* 是一个一维的包含数据类型的内存块。一个 PyTorch 的 <code>Tensor</code> 本质上是一个能够索引一个 *Storage* 的视角。你可以访问一个 <code>Tensor</code> 的 *Storage*:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points = torch.tensor([[<span class="number">1.0</span>, <span class="number">4.0</span>], [<span class="number">2.0</span>, <span class="number">1.0</span>], [<span class="number">3.0</span>, <span class="number">5.0</span>]])</span><br><span class="line"><span class="meta">>>> </span>points.storage()</span><br><span class="line"><span class="number">1.0</span></span><br><span class="line"><span class="number">4.0</span> </span><br><span class="line"><span class="number">2.0</span> </span><br><span class="line"><span class="number">1.0</span> </span><br><span class="line"><span class="number">3.0</span> </span><br><span class="line"><span class="number">5.0</span></span><br><span class="line">[torch.FloatStorage of size <span class="number">6</span>]</span><br></pre></td></tr></table></figure><p>你不能对一个 *Storage* 进行二维索引,因为 *Storage* 是一维的。因为 *Storage* 是一个张量的存储,修改它同样会改变张量本身。</p><h1 id="5-张量的-size,storage-offset-和-stride"><a href="#5-张量的-size,storage-offset-和-stride" class="headerlink" title="5. 张量的 size,storage offset 和 stride"></a>5. 张量的 size,storage offset 和 stride</h1><p>我们先定义一个张量:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points = torch.tensor([[<span class="number">1.0</span>, <span class="number">4.0</span>], [<span class="number">2.0</span>, <span class="number">1.0</span>], [<span class="number">3.0</span>, <span class="number">5.0</span>]])</span><br><span class="line"><span class="meta">>>> </span>points</span><br><span class="line">tensor([[<span class="number">1.</span>, <span class="number">4.</span>],</span><br><span class="line"> [<span class="number">2.</span>, <span class="number">1.</span>],</span><br><span class="line"> [<span class="number">3.</span>, <span class="number">5.</span>]])</span><br></pre></td></tr></table></figure><h2 id="5-1-张量的-size"><a href="#5-1-张量的-size" class="headerlink" title="5.1 张量的 size"></a>5.1 张量的 size</h2><p>获得一个张量的形状有四种方法:</p><ol><li><p><code>Tensor.size()</code></p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points.size()</span><br><span class="line">torch.Size([<span class="number">3</span>, <span class="number">2</span>])</span><br></pre></td></tr></table></figure></li><li><p><code>Tensor.shape</code></p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points.shape</span><br><span class="line">torch.Size([<span class="number">3</span>, <span 
class="number">2</span>])</span><br></pre></td></tr></table></figure><p>可以看出,两者的区别在于 <code>Tensor.shape</code> 没有 <code>()</code>。</p></li><li><p><code>Tensor.numel()</code></p></li></ol><p>查看 tensor 内的元素个数。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points.numel()</span><br><span class="line"><span class="number">6</span></span><br></pre></td></tr></table></figure><ol start="4"><li><code>Tensor.dim()</code> 或 <code>Tensor.ndim</code></li></ol><p>查看张量的维数,即有几维。</p><h2 id="5-2-张量的-storage-offset"><a href="#5-2-张量的-storage-offset" class="headerlink" title="5.2 张量的 storage offset"></a>5.2 张量的 storage offset</h2><p>查看张量内的相应元素与内存中第一个元素的相对位移。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>second_point = points[<span class="number">1</span>]</span><br><span class="line"><span class="meta">>>> </span>second_point.storage_offset()</span><br><span class="line"><span class="number">2</span></span><br></pre></td></tr></table></figure><p>因为 <code>points</code> 的 <code>storage</code> 是 <code>1.0, 4.0, 2.0, 1.0, 3.0, 5.0</code>,<code>second_point</code> 距离这个张量在内存中的第一个元素的距离是 2。</p><h2 id="5-3-张量的-stride"><a href="#5-3-张量的-stride" class="headerlink" title="5.3 张量的 stride"></a>5.3 张量的 stride</h2><p>指的是当索引增加 1 时,每个维度内需要跳过的元素个数,是一个元组。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points.stride()</span><br><span class="line">(<span class="number">2</span>, <span class="number">1</span>)</span><br></pre></td></tr></table></figure><h1 id="6-张量的变形、升维与降维"><a href="#6-张量的变形、升维与降维" class="headerlink" title="6. 张量的变形、升维与降维"></a>6. 
张量的变形、升维与降维</h1><h2 id="6-1-张量的变形:Tensor-view-,Tensor-reshape-或-Tensor-resize"><a href="#6-1-张量的变形:Tensor-view-,Tensor-reshape-或-Tensor-resize" class="headerlink" title="6.1 张量的变形:Tensor.view(),Tensor.reshape() 或 Tensor.resize()"></a>6.1 张量的变形:<code>Tensor.view()</code>,<code>Tensor.reshape()</code> 或 <code>Tensor.resize()</code></h2><p>括号里面的数值用小括号、中括号或者不用括号括起来都可以,维数自定,只要所有数字的乘积与原尺寸的乘积相同即可。<code>Tensor.view()</code> 和 <code>Tensor.reshape()</code> 的维度中可以有一个 -1,表示该维的长度由其他维度决定。<code>Tensor.resize()</code> 的维度中不能有 -1。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>points.reshape((<span class="number">1</span>, <span class="number">2</span>, <span class="number">1</span>, <span class="number">-1</span>))</span><br><span class="line">tensor([[[[<span class="number">1.</span>, <span class="number">4.</span>, <span class="number">2.</span>]],</span><br><span class="line"> [[<span class="number">1.</span>, <span class="number">3.</span>, <span class="number">5.</span>]]]])</span><br></pre></td></tr></table></figure><h2 id="6-2-张量的转置:Tensor-t-Tensor-T-或-Tensor-transpose-dim1-dim2"><a href="#6-2-张量的转置:Tensor-t-Tensor-T-或-Tensor-transpose-dim1-dim2" class="headerlink" title="6.2 张量的转置:Tensor.t() Tensor.T 或 Tensor.transpose(dim1, dim2)"></a>6.2 张量的转置:<code>Tensor.t()</code> <code>Tensor.T</code> 或 <code>Tensor.transpose(dim1, dim2)</code></h2><p><code>Tensor.t()</code>只能转置维度小于等于 2 的张量,转置第 0、1 维。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.arange(<span class="number">4</span>).reshape(<span class="number">2</span>, <span class="number">2</span>)</span><br><span class="line"><span class="meta">>>> </span>a.t()</span><br><span class="line">tensor([[<span class="number">0</span>, <span class="number">2</span>],</span><br><span class="line"> [<span class="number">1</span>, <span class="number">3</span>]])</span><br></pre></td></tr></table></figure><p><code>Tensor.T</code> 把整个张量的维度进行颠倒。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>new_points = points.reshape(<span class="number">1</span>, <span class="number">2</span>, <span class="number">-1</span>, <span class="number">1</span>, <span class="number">3</span>)</span><br><span class="line"><span class="meta">>>> </span>new_points.shape</span><br><span class="line">torch.Size([<span class="number">1</span>, <span class="number">2</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">3</span>])</span><br><span class="line"><span class="meta">>>> </span>new = new_points.T</span><br><span class="line"><span class="meta">>>> </span>new.shape</span><br><span class="line">torch.Size([<span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">1</span>])</span><br></pre></td></tr></table></figure><p>而 <code>Tensor.transpose(dim1, 
dim2)</code> 可以转置任意两个维度。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>new2 = new_points.transpose(<span class="number">1</span>, <span class="number">4</span>)</span><br><span class="line"><span class="meta">>>> </span>new2.shape</span><br><span class="line">torch.Size([<span class="number">1</span>, <span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>])</span><br></pre></td></tr></table></figure><h2 id="6-3-张量的降维:Tensor-squeeze"><a href="#6-3-张量的降维:Tensor-squeeze" class="headerlink" title="6.3 张量的降维:Tensor.squeeze()"></a>6.3 张量的降维:<code>Tensor.squeeze()</code></h2><p>所谓降维,就是消去元素个数为 1 的维度。可以指定想消去的维度,若该维度不能消去,则该命令无效,但是不报错。若没有指定维度,则消去所有长度为 1 的维度。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>new_points2 = new.squeeze(<span class="number">1</span>)</span><br><span class="line"><span class="meta">>>> </span>new_points2.shape <span class="comment"># 降维成功</span></span><br><span class="line">torch.Size([<span class="number">3</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">1</span>])</span><br><span class="line"><span class="meta">>>> </span>new_points3 = new.squeeze(<span class="number">0</span>)</span><br><span class="line"><span class="meta">>>> </span>new_points3.shape <span class="comment"># 降维失败</span></span><br><span class="line">torch.Size([<span class="number">3</span>, <span class="number">1</span>, <span class="number">1</span>, <span class="number">2</span>, <span class="number">1</span>])</span><br><span class="line"><span class="meta">>>> </span>new_points4 = new_points.squeeze()</span><br><span class="line"><span class="meta">>>> </span>new_points4.shape <span class="comment"># 降维成功</span></span><br><span class="line">torch.Size([<span class="number">2</span>, <span class="number">3</span>])</span><br></pre></td></tr></table></figure><h2 id="6-4-张量的升维:Tensor-unsqueeze"><a href="#6-4-张量的升维:Tensor-unsqueeze" class="headerlink" title="6.4 张量的升维:Tensor.unsqueeze()"></a>6.4 张量的升维:<code>Tensor.unsqueeze()</code></h2><p>升维必须指定增加的维度,必须在张量的已有维度 <code>(-dim-1, dim+1)</code> 之间。相当于在两个维度之间“加塞”,后面的维度顺移一位。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>new_points4.unsqueeze(<span class="number">2</span>).shape</span><br><span class="line">torch.Size([<span class="number">2</span>, <span class="number">3</span>, <span class="number">1</span>])</span><br></pre></td></tr></table></figure><h1 id="7-张量的复制与原地修改"><a href="#7-张量的复制与原地修改" class="headerlink" title="7. 张量的复制与原地修改"></a>7. 
张量的复制与原地修改</h1><p>因为张量本质上是连续内存地址的索引,我们把一段内存赋值给一个变量,再赋值给另一个变量后,修改一个变量中的元素往往会同时改变另一个变量中的相同元素:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>])</span><br><span class="line"><span class="meta">>>> </span>b = a</span><br><span class="line"><span class="meta">>>> </span>b[<span class="number">1</span>] = <span class="number">10</span></span><br><span class="line"><span class="meta">>>> </span>a, b</span><br><span class="line">(tensor([ <span class="number">1</span>, <span class="number">10</span>, <span class="number">3</span>, <span class="number">4</span>]), tensor([ <span class="number">1</span>, <span class="number">10</span>, <span class="number">3</span>, <span class="number">4</span>]))</span><br></pre></td></tr></table></figure><p>我们希望能够控制这种现象。</p><h2 id="7-1-张量的复制"><a href="#7-1-张量的复制" class="headerlink" title="7.1 张量的复制"></a>7.1 张量的复制</h2><p>使用 <code>Tensor.clone()</code> 复制一段内存上的数据到另一段内存上,这两个张量相互独立:修改 <code>b</code> 不会再影响 <code>a</code>。</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br></pre></td><td class="code"><pre><span class="line"><span class="meta">>>> </span>a = torch.tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>])</span><br><span class="line"><span class="meta">>>> </span>b = a.clone()</span><br><span class="line"><span class="meta">>>> </span>b[<span class="number">1</span>] = <span class="number">10</span></span><br><span class="line"><span class="meta">>>> </span>a, b</span><br><span class="line">(tensor([<span class="number">1</span>, <span class="number">2</span>, <span class="number">3</span>, <span class="number">4</span>]), tensor([ <span class="number">1</span>, <span class="number">10</span>, <span class="number">3</span>, <span class="number">4</span>]))</span><br></pre></td></tr></table></figure><h2 id="7-2-张量的原地修改"><a href="#7-2-张量的原地修改" class="headerlink" title="7.2 张量的原地修改"></a>7.2 张量的原地修改</h2><p>如果我们能够避免引入新张量,直接在原始张量上修改,不就可以避免混淆了吗?很多张量操作都支持原地(in-place)操作,只要在原始函数后面加上 <code>_</code> 就表明是原地修改。比如:</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br></pre></td><td class="code"><pre><span class="line">>>> a = torch.ones(2, 2) # 创建一个 2 x 2 的全 1 张量</span><br><span class="line">>>> a</span><br><span class="line">tensor([[1., 1.],</span><br><span class="line"> [1., 1.]])</span><br><span class="line">>>> a.add_(1) # 原地每个元素加 1</span><br><span class="line">>>> a</span><br><span class="line">tensor([[2., 2.],</span><br><span class="line"> [2., 2.]])</span><br></pre></td></tr></table></figure>
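<p>最后补一个小示例(示意代码):原地操作会返回修改后的张量本身,所以在交互环境里调用后会直接打印结果;除 <code>add_()</code> 外,常见的还有 <code>sub_()</code>、<code>mul_()</code>、<code>zero_()</code> 等:</p><figure class="highlight plain"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br></pre></td><td class="code"><pre><span class="line">>>> a = torch.ones(2, 2)</span><br><span class="line">>>> a.mul_(10)  # 原地每个元素乘 10,返回 a 本身</span><br><span class="line">tensor([[10., 10.],</span><br><span class="line">        [10., 10.]])</span><br><span class="line">>>> a.zero_()   # 原地全部清零</span><br><span class="line">tensor([[0., 0.],</span><br><span class="line">        [0., 0.]])</span><br></pre></td></tr></table></figure>]]></content>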
<tags>
<tag> deep learning </tag>
<tag> PyTorch </tag>
</tags>
</entry>
<entry>
<title>[Python] Lessons Learned 1: Slicing DataFrames</title>
<link href="2020/04/28/Python-%E7%BB%8F%E9%AA%8C%E6%80%BB%E7%BB%93-1%EF%BC%9A%E6%95%B0%E6%8D%AE%E6%A1%86%E7%9A%84%E5%88%87%E7%89%87/"/>
<url>2020/04/28/Python-%E7%BB%8F%E9%AA%8C%E6%80%BB%E7%BB%93-1%EF%BC%9A%E6%95%B0%E6%8D%AE%E6%A1%86%E7%9A%84%E5%88%87%E7%89%87/</url>
<content type="html"><![CDATA[<p>Everyone knows Python as an efficient, concise, elegant language. It also has plenty of pitfalls, and with this post I am starting a new series to share, at irregular intervals, the lessons I have learned in study and practice.</p><p>The first lesson is a gripe about DataFrame slicing. Python owes much of its practicality to its third-party modules (pandas, that workhorse of data science, among them). But the abundance of modules has a side effect: inconsistent syntax. While learning pandas I was left scratching my head by the convoluted syntax for slicing a DataFrame.</p><a id="more"></a><p>This post draws on <code>http://chris.friedline.net/2015-12-15-rutgers/lessons/python2/02-index-slice-subset.html</code>, with thanks.</p><p>DataFrame slicing grew out of list slicing. But a list is one-dimensional and a DataFrame is two-dimensional, so DataFrames have slicing methods of their own. The result is two styles of slicing: the native style and the pandas style (the names are my own). Before walking through them, let us build a dataset:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br><span class="line">10</span><br><span class="line">11</span><br><span class="line">12</span><br></pre></td><td class="code"><pre><span class="line">>>> import pandas as pd</span><br><span class="line">>>> from sklearn.datasets import load_iris  # the iris dataset loader</span><br><span class="line">>>> iris = pd.DataFrame(load_iris()["data"])  # load iris into a DataFrame</span><br><span class="line">>>> iris.columns = ["sepal_length", "sepal_width",</span><br><span class="line">...                 "petal_length", "petal_width"]  # set column names</span><br><span class="line">>>> from string import ascii_lowercase  # the lowercase alphabet</span><br><span class="line">>>> idx = []</span><br><span class="line">>>> for i in ascii_lowercase:</span><br><span class="line">...     for j in ascii_lowercase:</span><br><span class="line">...         idx.append(i + j)  # all two-letter combinations</span><br><span class="line">>>> iris.index = idx[:150]  # set row names</span><br><span class="line">>>> iris.head()</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th><th>sepal_width</th><th>petal_length</th><th>petal_width</th></tr></thead><tbody><tr><td>aa</td><td>5.1</td><td>3.5</td><td>1.4</td><td>0.2</td></tr><tr><td>ab</td><td>4.9</td><td>3.0</td><td>1.4</td><td>0.2</td></tr><tr><td>ac</td><td>4.7</td><td>3.2</td><td>1.3</td><td>0.2</td></tr><tr><td>ad</td><td>4.6</td><td>3.1</td><td>1.5</td><td>0.2</td></tr><tr><td>ae</td><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td></tr></tbody></table><h1 id="原生风格"><a href="#原生风格" class="headerlink" title="Native style"></a>Native style</h1><h2 id="切片单列"><a href="#切片单列" class="headerlink" title="Slicing a single column"></a>Slicing a single column</h2><ul><li>The <code>df.column</code> method</li></ul><p>Attach the column name directly to the DataFrame with <code>.</code>. For example:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br></pre></td><td class="code"><pre><span class="line">>>> iris.sepal_length[1:5]</span><br><span class="line">ab    4.9</span><br><span class="line">ac    4.7</span><br><span class="line">ad    4.6</span><br><span class="line">ae    5.0</span><br><span class="line">Name: sepal_length, dtype: float64</span><br></pre></td></tr></table></figure><p>No quotes around the column name, which is very convenient. There is a latent limitation, though: if the name contains a space (or is otherwise not a valid Python identifier), this method breaks down and you need the one below.</p><ul><li>The <code>df["column"]</code> method</li></ul><p>The advantage here is that the quoted name may contain special characters such as spaces. Slicing takes slightly more typing this way, but it is acceptable.</p><ul><li>The <code>df[["column"]]</code> method</li></ul><p>This one, like the previous two, selects a single column. What is the difference, then? Consider the following example:</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br><span class="line">5</span><br><span class="line">6</span><br><span class="line">7</span><br><span class="line">8</span><br><span class="line">9</span><br></pre></td><td class="code"><pre><span class="line">>>> iris["sepal_length"][1:5]</span><br><span class="line"></span><br><span class="line">ab    4.9</span><br><span class="line">ac    4.7</span><br><span class="line">ad    4.6</span><br><span class="line">ae    5.0</span><br><span class="line">Name: sepal_length, dtype: float64</span><br><span class="line"></span><br><span class="line">>>> iris[["sepal_length"]][1:5]</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th></tr></thead><tbody><tr><td>ab</td><td>4.9</td></tr><tr><td>ac</td><td>4.7</td></tr><tr><td>ad</td><td>4.6</td></tr><tr><td>ae</td><td>5.0</td></tr></tbody></table><p>The difference: single brackets return a Series, while double brackets return a DataFrame.</p>
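<p>A quick way to make that difference concrete is to check the types directly (a minimal check on the <code>iris</code> frame built above):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">>>> type(iris["sepal_length"])  # single brackets: a Series</span><br><span class="line"><class 'pandas.core.series.Series'></span><br><span class="line">>>> type(iris[["sepal_length"]])  # double brackets: a DataFrame</span><br><span class="line"><class 'pandas.core.frame.DataFrame'></span><br></pre></td></tr></table></figure>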
<h2 id="切片多列"><a href="#切片多列" class="headerlink" title="Slicing multiple columns"></a>Slicing multiple columns</h2><ul><li>The <code>df[["column1", "column2", ...]]</code> method<br>Several columns taken together form a DataFrame, so the double brackets are mandatory, and the column names must be quoted.</li><li>The <code>df[list]</code> method<br>This looks like an inconsistency, but it is really the same syntax: the inner pair of brackets in <code>df[["column1", "column2"]]</code> is just a list literal. Build the list beforehand and you can pass it with single brackets and no extra quotes, as in the example below (an equivalence check follows the table).<figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">>>> lst = ["sepal_length", "petal_length"]</span><br><span class="line">>>> iris[lst].head()</span><br></pre></td></tr></table></figure></li></ul><table><thead><tr><th></th><th>sepal_length</th><th>petal_length</th></tr></thead><tbody><tr><td>aa</td><td>5.1</td><td>1.4</td></tr><tr><td>ab</td><td>4.9</td><td>1.4</td></tr><tr><td>ac</td><td>4.7</td><td>1.3</td></tr><tr><td>ad</td><td>4.6</td><td>1.5</td></tr><tr><td>ae</td><td>5.0</td><td>1.4</td></tr></tbody></table>
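<p>As promised, the equivalence check (a minimal sketch reusing <code>lst</code> from above):</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br></pre></td><td class="code"><pre><span class="line">>>> iris[lst].equals(iris[["sepal_length", "petal_length"]])</span><br><span class="line">True</span><br></pre></td></tr></table></figure>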
<h2 id="切片行"><a href="#切片行" class="headerlink" title="Slicing rows"></a>Slicing rows</h2><ul><li><p>Slicing by integer position<br>Even when the rows carry custom index labels, you can still slice them by integer position.</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> iris[1:5]</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th><th>sepal_width</th><th>petal_length</th><th>petal_width</th></tr></thead><tbody><tr><td>ab</td><td>4.9</td><td>3.0</td><td>1.4</td><td>0.2</td></tr><tr><td>ac</td><td>4.7</td><td>3.2</td><td>1.3</td><td>0.2</td></tr><tr><td>ad</td><td>4.6</td><td>3.1</td><td>1.5</td><td>0.2</td></tr><tr><td>ae</td><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td></tr></tbody></table></li><li><p>Slicing by row label</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> iris["ae":"ag"]</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th><th>sepal_width</th><th>petal_length</th><th>petal_width</th></tr></thead><tbody><tr><td>ae</td><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td></tr><tr><td>af</td><td>5.4</td><td>3.9</td><td>1.7</td><td>0.4</td></tr><tr><td>ag</td><td>4.6</td><td>3.4</td><td>1.4</td><td>0.3</td></tr></tbody></table></li></ul><p>Row slicing in plain brackets has one ability that column slicing lacks: taking a contiguous range. Pandas resolves the row-or-column question by convention, not by guessing at names: inside plain brackets a slice always refers to rows, while a single label or a list of labels always refers to columns. To state explicitly which axis you mean, use the pandas-style slicing below.</p><h1 id="pandas-风格切片"><a href="#pandas-风格切片" class="headerlink" title="pandas-style slicing"></a>pandas-style slicing</h1><h2 id="df-loc-“indexes”-“columns”-基于行、列的名称切片"><a href="#df-loc-“indexes”-“columns”-基于行、列的名称切片" class="headerlink" title="df.loc[rows, columns]: slicing rows and columns by name"></a>df.loc[rows, columns]: slicing rows and columns by name</h2><p>Both positions accept lists, so you can select several rows and several columns at once; you can also take a contiguous range of rows or columns with <code>:</code>.</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> iris.loc["ae":"ag", ["sepal_length", "petal_length"]]</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th><th>petal_length</th></tr></thead><tbody><tr><td>ae</td><td>5.0</td><td>1.4</td></tr><tr><td>af</td><td>5.4</td><td>1.7</td></tr><tr><td>ag</td><td>4.6</td><td>1.4</td></tr></tbody></table>
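<p>One subtlety deserves a check of its own (a minimal sketch; <code>iloc</code> is introduced just below): label slices with <code>loc</code> include both endpoints, whereas position slices exclude the stop position.</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br><span class="line">2</span><br><span class="line">3</span><br><span class="line">4</span><br></pre></td><td class="code"><pre><span class="line">>>> iris.loc["aa":"ab"].shape  # label slice: both endpoints included</span><br><span class="line">(2, 4)</span><br><span class="line">>>> iris.iloc[0:1].shape  # position slice: the stop position is excluded</span><br><span class="line">(1, 4)</span><br></pre></td></tr></table></figure>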
<p>To take every row or every column, use a bare <code>:</code>.</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">iris.loc["ae":"ag", :]  # every column</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th><th>sepal_width</th><th>petal_length</th><th>petal_width</th></tr></thead><tbody><tr><td>ae</td><td>5.0</td><td>3.6</td><td>1.4</td><td>0.2</td></tr><tr><td>af</td><td>5.4</td><td>3.9</td><td>1.7</td><td>0.4</td></tr><tr><td>ag</td><td>4.6</td><td>3.4</td><td>1.4</td><td>0.3</td></tr></tbody></table><h2 id="df-iloc-“indexes”-“columns”-基于行、列的索引切片"><a href="#df-iloc-“indexes”-“columns”-基于行、列的索引切片" class="headerlink" title="df.iloc[rows, columns]: slicing rows and columns by integer position"></a>df.iloc[rows, columns]: slicing rows and columns by integer position</h2><p>You can likewise slice by the integer positions of rows and columns. <code>iloc</code> mirrors the structure of <code>loc</code>, with one difference inherited from native slicing: its slices exclude the stop position.</p><figure class="highlight py"><table><tr><td class="gutter"><pre><span class="line">1</span><br></pre></td><td class="code"><pre><span class="line">>>> iris.iloc[1:3, 0:2]</span><br></pre></td></tr></table></figure><table><thead><tr><th></th><th>sepal_length</th><th>sepal_width</th></tr></thead><tbody><tr><td>ab</td><td>4.9</td><td>3.0</td></tr><tr><td>ac</td><td>4.7</td><td>3.2</td></tr></tbody></table>]]></content>
<tags>
<tag> Python </tag>
</tags>
</entry>
<entry>
<title>Hello World</title>
<link href="2020/04/28/hello-world/"/>
<url>2020/04/28/hello-world/</url>
<content type="html"><![CDATA[<p>Welcome to my blog!</p><p>As a natural language processing enthusiast, I will use this blog to follow research and practice in natural language processing and computational linguistics.</p><p>I still have a great deal to learn. Stay hungry, stay foolish.</p>]]></content>
<categories>
<category> others </category>
</categories>
</entry>
</search>