词嵌入应当能够完成类比运算,例如 King - Man + Woman ≈ Queen(as explained in this article)。但是,使用 spaCy 的预训练词向量时,我无法复现这一结果:King - Man + Woman 与 Queen 之间的差值并不接近零。
import spacy
import en_core_web_sm
# Load the small English pipeline through its installed package module.
# To compare against the medium model, use spacy.load('en_core_web_md') instead.
nlp = en_core_web_sm.load()
doc = nlp('queen king woman man')
# Take the vector of each of the four tokens, in the order they appear in the text.
queen, king, woman, man = (doc[i].vector for i in range(4))
# Analogy arithmetic: king - man + woman is expected to land near queen.
vec = king - man + woman
# Element-wise residual, displayed by the REPL.
vec - queen
结果是:
array([ 0.10928726, 1.5129069 , 0.22144175, -1.0195163 , -0.88018465,
1.0273552 , -0.42121184, -0.6132709 , -5.506116 , -1.8500991 ,
-0.15576434, -1.1081355 , 0.33168507, -3.3569758 , -3.671307 ,
0.41009247, 5.0559406 , 1.6673484 , 1.6196246 , 2.3392878 ,
-1.4170032 , 1.0845371 , 1.1150997 , 1.4959896 , -5.9387603 ,
2.71976 , -5.1596265 , -2.1413157 , -2.0650306 , -0.90464056,
-3.662921 , -1.9780679 , 0.3792592 , -1.1127007 , -2.763383 ,
-0.46687317, -3.3972526 , -1.0455723 , 4.713142 , -1.3429235 ,
1.4183658 , -1.38419 , 3.2157912 , 0.4593829 , 2.57287 ,
-5.232533 , 2.007104 , -0.03439535, -2.5858183 , 2.3942559 ,
-2.2274508 , 1.1235554 , 1.8343859 , -3.809722 , 2.3434563 ,
6.6838984 , -0.79330105, -0.3786683 , 0.5149512 , -2.567075 ,
-4.5407395 , 0.15355158, 0.4791546 , 2.6068583 , 0.06677404,
-0.36967564, -5.109796 , 0.45319676, 7.158951 , 1.0552151 ,
-0.72934663, 1.5460184 , -0.41246212, -3.068016 , -1.2780238 ,
-2.256475 , 0.20324552, -0.7423974 , 2.6825244 , -1.8383589 ,
2.2891805 , 1.542151 , -2.3867102 , 0.03401029, -0.70230985,
1.4130044 , -2.416402 , 0.6862675 , -2.270489 , 3.9625044 ,
2.463019 , 1.3068041 , 3.4472568 , 5.8497505 , 7.2417293 ,
-1.8955674 ], dtype=float32)
可能有什么问题?
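事实证明,这是由词向量的质量决定的。小模型 en_core_web_sm 并不附带静态词向量,问题中得到的 96 维 `token.vector` 来自上下文相关的张量,不适合做词类比(下面第一段代码复现了与问题中完全相同的结果)。换用附带 300 维词向量的 en_core_web_md 或 en_core_web_lg 后,差值明显更接近于零(本例中 md 与 lg 的输出完全相同)。另外,即使词向量质量很好,King - Man + Woman 也不会恰好等于 Queen。因此更合理的检验方式是:计算结果向量与各候选词之间的余弦相似度,看最接近的词是否为 Queen(需排除三个输入词)。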
# Small pipeline: reproduces exactly the residual shown in the question.
nlp = spacy.load('en_core_web_sm')
doc = nlp('queen king woman man')
# One vector per token, in text order.
queen, king, woman, man = (doc[i].vector for i in range(4))
vec = king - man + woman
# Residual between the analogy result and "queen".
vec - queen
array([ 0.10928726, 1.5129069 , 0.22144175, -1.0195163 , -0.88018465,
1.0273552 , -0.42121184, -0.6132709 , -5.506116 , -1.8500991 ,
-0.15576434, -1.1081355 , 0.33168507, -3.3569758 , -3.671307 ,
0.41009247, 5.0559406 , 1.6673484 , 1.6196246 , 2.3392878 ,
-1.4170032 , 1.0845371 , 1.1150997 , 1.4959896 , -5.9387603 ,
2.71976 , -5.1596265 , -2.1413157 , -2.0650306 , -0.90464056,
-3.662921 , -1.9780679 , 0.3792592 , -1.1127007 , -2.763383 ,
-0.46687317, -3.3972526 , -1.0455723 , 4.713142 , -1.3429235 ,
1.4183658 , -1.38419 , 3.2157912 , 0.4593829 , 2.57287 ,
-5.232533 , 2.007104 , -0.03439535, -2.5858183 , 2.3942559 ,
-2.2274508 , 1.1235554 , 1.8343859 , -3.809722 , 2.3434563 ,
6.6838984 , -0.79330105, -0.3786683 , 0.5149512 , -2.567075 ,
-4.5407395 , 0.15355158, 0.4791546 , 2.6068583 , 0.06677404,
-0.36967564, -5.109796 , 0.45319676, 7.158951 , 1.0552151 ,
-0.72934663, 1.5460184 , -0.41246212, -3.068016 , -1.2780238 ,
-2.256475 , 0.20324552, -0.7423974 , 2.6825244 , -1.8383589 ,
2.2891805 , 1.542151 , -2.3867102 , 0.03401029, -0.70230985,
1.4130044 , -2.416402 , 0.6862675 , -2.270489 , 3.9625044 ,
2.463019 , 1.3068041 , 3.4472568 , 5.8497505 , 7.2417293 ,
-1.8955674 ], dtype=float32)
# Medium pipeline: 300-dimensional vectors; the residual shrinks markedly.
nlp = spacy.load('en_core_web_md')
doc = nlp('queen king woman man')
# One vector per token, in text order.
queen, king, woman, man = (doc[i].vector for i in range(4))
vec = king - man + woman
# Residual between the analogy result and "queen".
vec - queen
array([ 0.10458702, -0.05152999, -0.01085299, 0.40603995, 0.111525 ,
0.03181005, -0.18277001, 0.10793996, 0.22586 , 0.42549992,
-0.620518 , 0.09305897, -0.0758817 , -0.29067168, -0.297841 ,
-0.43369 , -0.44859397, 0.21168 , -0.172735 , 0.24211 ,
0.20211 , -0.15502006, -0.04844499, -0.202636 , -0.21129996,
0.457768 , 0.03138995, 0.13294101, -0.534806 , -0.07134694,
-0.157518 , -0.05403006, -0.14246997, -0.773906 , 0.15866998,
-0.12601201, -0.19204 , -0.40347007, 0.05978 , 0.5203604 ,
0.37192 , -0.252379 , -0.097138 , -0.40504098, 0.25123 ,
-0.03785798, -0.11933102, -0.00672996, 0.40258 , 0.02721703,
-0.29956898, 0.34834102, -0.15371901, -0.14056298, 0.17291501,
0.73967993, -0.0257776 , -0.28438202, -0.337454 , 0.12431702,
0.063307 , -0.391515 , -0.24294749, 0.3378177 , 0.37893206,
0.14127994, 0.70388097, 0.021424 , 0.142003 , 0.20465 ,
-0.36599994, -0.14310999, -0.17243698, -0.00424001, 0.67148 ,
-0.17920549, 0.45753998, 0.17486003, -0.23000398, 0.06431001,
0.13716793, -0.172827 , -0.32512403, 0.22375101, -0.3474555 ,
0.447715 , 0.28867 , -0.14638105, -0.04995 , -0.437648 ,
-0.2236634 , -0.14245 , 0.03281999, -0.16247103, 0.5124899 ,
-0.40227997, -0.150479 , -0.38445002, 0.359772 , 0.30387995,
0.577236 , 0.534451 , 0.281598 , 0.126359 , -0.019406 ,
-0.26014996, -0.15996996, -0.15767002, 0.00154799, 0.195612 ,
-0.13352397, 0.01087999, -0.080301 , -0.20445602, -0.11846301,
-0.371925 , 0.39347702, 0.26368502, 0.392657 , 0.48374 ,
0.06531 , 0.068128 , 0.11742002, 0.04229499, 0.10026699,
0.30376 , 0.06063001, 0.3936985 , -0.10366529, 0.065814 ,
0.14065003, 0.17174399, -0.20236002, -0.55088 , -0.722872 ,
-0.48885 , -0.37717 , 0.07013199, -0.52826 , 0.096489 ,
0.5985999 , -0.13812901, -0.11418399, -0.190035 , 0.06799701,
0.02872499, 0.387542 , 0.00787 , -0.623389 , -0.09111011,
-0.22364 , -0.1886197 , -0.20119 , 0.22608899, -0.24934301,
0.08535001, -0.27039596, 0.30038005, -0.090203 , -0.14802799,
0.14603001, 0.21248001, 0.118833 , -0.07153228, -0.12797996,
-0.274443 , 0.30433598, 0.29837996, -0.01640302, 0.11600998,
-0.33268997, -0.056754 , 0.13773698, -0.188018 , -0.51105094,
-0.2561026 , -0.07734999, -0.457643 , 0.12696004, -0.25476858,
0.01485402, -0.27168003, -0.09315271, -0.18197 , 0.46563497,
0.34945 , 0.27662 , -0.138596 , 0.200928 , -0.34992003,
-0.48564997, -0.603999 , -0.181443 , -0.11616989, 0.129803 ,
0.02417099, 0.05545059, 0.117446 , -0.03544599, -0.57339 ,
0.44310898, 0.33150995, 0.01238599, -0.21157703, -0.03491596,
0.26410997, -0.22768001, -0.252998 , -0.23517999, 0.48754 ,
0.194835 , -0.27317 , -0.440702 , 0.367029 , 0.09925799,
-0.06908001, -0.14320281, 0.22666103, 0.2794511 , 0.29843 ,
0.21248499, -0.635843 , 0.20785001, 0.483295 , -0.47914696,
-0.03455502, 0.34644902, -0.37480602, -0.15627 , 0.12277907,
-0.04933499, 0.005468 , 0.00519997, -0.37172398, -0.175451 ,
-0.18385059, -0.21175501, -0.313944 , 0.07360198, -0.01590204,
-0.17416 , -0.00090003, 0.11262399, -0.48282 , -0.10517 ,
0.05565304, 0.32160503, -0.24056101, -0.30389994, -0.5073231 ,
0.33911803, -0.23648998, 0.06108901, 0.23029798, -0.02688998,
0.08346 , 0.17561206, 0.331848 , -0.09330803, 0.2918205 ,
0.277062 , -0.32242298, -0.002744 , 0.36982 , 0.51171 ,
-0.39322 , -0.16557002, -0.18774 , -0.01507998, -0.284651 ,
-0.07072806, -0.05853601, -0.06321001, -0.09849399, -0.09514015,
-0.23703995, -0.17931 , 0.38357297, 0.01018202, 0.10888296,
0.29964393, 0.12595999, 0.605805 , 0.04320699, 0.18856 ,
0.636185 , -0.18775499, 0.421264 , -0.15406296, -0.36692598,
0.094318 , 0.02511001, 0.06609299, -0.17440999, 0.00357999,
0.08757752, 0.04765201, 0.27466798, 0.7439101 , -0.01412702],
dtype=float32)
# Large pipeline: for these four words the output matches the medium model exactly.
nlp = spacy.load('en_core_web_lg')
doc = nlp('queen king woman man')
# One vector per token, in text order.
queen, king, woman, man = (doc[i].vector for i in range(4))
vec = king - man + woman
# Residual between the analogy result and "queen".
vec - queen
array([ 0.10458702, -0.05152999, -0.01085299, 0.40603995, 0.111525 ,
0.03181005, -0.18277001, 0.10793996, 0.22586 , 0.42549992,
-0.620518 , 0.09305897, -0.0758817 , -0.29067168, -0.297841 ,
-0.43369 , -0.44859397, 0.21168 , -0.172735 , 0.24211 ,
0.20211 , -0.15502006, -0.04844499, -0.202636 , -0.21129996,
0.457768 , 0.03138995, 0.13294101, -0.534806 , -0.07134694,
-0.157518 , -0.05403006, -0.14246997, -0.773906 , 0.15866998,
-0.12601201, -0.19204 , -0.40347007, 0.05978 , 0.5203604 ,
0.37192 , -0.252379 , -0.097138 , -0.40504098, 0.25123 ,
-0.03785798, -0.11933102, -0.00672996, 0.40258 , 0.02721703,
-0.29956898, 0.34834102, -0.15371901, -0.14056298, 0.17291501,
0.73967993, -0.0257776 , -0.28438202, -0.337454 , 0.12431702,
0.063307 , -0.391515 , -0.24294749, 0.3378177 , 0.37893206,
0.14127994, 0.70388097, 0.021424 , 0.142003 , 0.20465 ,
-0.36599994, -0.14310999, -0.17243698, -0.00424001, 0.67148 ,
-0.17920549, 0.45753998, 0.17486003, -0.23000398, 0.06431001,
0.13716793, -0.172827 , -0.32512403, 0.22375101, -0.3474555 ,
0.447715 , 0.28867 , -0.14638105, -0.04995 , -0.437648 ,
-0.2236634 , -0.14245 , 0.03281999, -0.16247103, 0.5124899 ,
-0.40227997, -0.150479 , -0.38445002, 0.359772 , 0.30387995,
0.577236 , 0.534451 , 0.281598 , 0.126359 , -0.019406 ,
-0.26014996, -0.15996996, -0.15767002, 0.00154799, 0.195612 ,
-0.13352397, 0.01087999, -0.080301 , -0.20445602, -0.11846301,
-0.371925 , 0.39347702, 0.26368502, 0.392657 , 0.48374 ,
0.06531 , 0.068128 , 0.11742002, 0.04229499, 0.10026699,
0.30376 , 0.06063001, 0.3936985 , -0.10366529, 0.065814 ,
0.14065003, 0.17174399, -0.20236002, -0.55088 , -0.722872 ,
-0.48885 , -0.37717 , 0.07013199, -0.52826 , 0.096489 ,
0.5985999 , -0.13812901, -0.11418399, -0.190035 , 0.06799701,
0.02872499, 0.387542 , 0.00787 , -0.623389 , -0.09111011,
-0.22364 , -0.1886197 , -0.20119 , 0.22608899, -0.24934301,
0.08535001, -0.27039596, 0.30038005, -0.090203 , -0.14802799,
0.14603001, 0.21248001, 0.118833 , -0.07153228, -0.12797996,
-0.274443 , 0.30433598, 0.29837996, -0.01640302, 0.11600998,
-0.33268997, -0.056754 , 0.13773698, -0.188018 , -0.51105094,
-0.2561026 , -0.07734999, -0.457643 , 0.12696004, -0.25476858,
0.01485402, -0.27168003, -0.09315271, -0.18197 , 0.46563497,
0.34945 , 0.27662 , -0.138596 , 0.200928 , -0.34992003,
-0.48564997, -0.603999 , -0.181443 , -0.11616989, 0.129803 ,
0.02417099, 0.05545059, 0.117446 , -0.03544599, -0.57339 ,
0.44310898, 0.33150995, 0.01238599, -0.21157703, -0.03491596,
0.26410997, -0.22768001, -0.252998 , -0.23517999, 0.48754 ,
0.194835 , -0.27317 , -0.440702 , 0.367029 , 0.09925799,
-0.06908001, -0.14320281, 0.22666103, 0.2794511 , 0.29843 ,
0.21248499, -0.635843 , 0.20785001, 0.483295 , -0.47914696,
-0.03455502, 0.34644902, -0.37480602, -0.15627 , 0.12277907,
-0.04933499, 0.005468 , 0.00519997, -0.37172398, -0.175451 ,
-0.18385059, -0.21175501, -0.313944 , 0.07360198, -0.01590204,
-0.17416 , -0.00090003, 0.11262399, -0.48282 , -0.10517 ,
0.05565304, 0.32160503, -0.24056101, -0.30389994, -0.5073231 ,
0.33911803, -0.23648998, 0.06108901, 0.23029798, -0.02688998,
0.08346 , 0.17561206, 0.331848 , -0.09330803, 0.2918205 ,
0.277062 , -0.32242298, -0.002744 , 0.36982 , 0.51171 ,
-0.39322 , -0.16557002, -0.18774 , -0.01507998, -0.284651 ,
-0.07072806, -0.05853601, -0.06321001, -0.09849399, -0.09514015,
-0.23703995, -0.17931 , 0.38357297, 0.01018202, 0.10888296,
0.29964393, 0.12595999, 0.605805 , 0.04320699, 0.18856 ,
0.636185 , -0.18775499, 0.421264 , -0.15406296, -0.36692598,
0.094318 , 0.02511001, 0.06609299, -0.17440999, 0.00357999,
0.08757752, 0.04765201, 0.27466798, 0.7439101 , -0.01412702],
dtype=float32)