Skip to content

Commit ba51712

Browse files
committed
Adds dlm.performance.distance_matrix
1 parent f828921 commit ba51712

File tree

1 file changed

+37
-3
lines changed

1 file changed

+37
-3
lines changed

discriminative_lexicon_model/performance.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,24 @@
11
import numpy as np
22
import pandas as pd
3+
import xarray as xr
34
import scipy.spatial.distance as spd
45

56
def accuracy(hat, mat, distance=False):
    """
    Return the proportion of rows whose top-1 prediction is marked correct.

    Calls predict_df with max_guess=1, sums the "acc" column of the returned
    frame and divides that sum by the number of rows in the frame.

    Parameters
    ----------
    hat : prediction matrix, forwarded unchanged to predict_df.
    mat : gold-standard matrix, forwarded unchanged to predict_df.
    distance : bool
        Forwarded unchanged to predict_df.

    Returns
    -------
    The sum of the "acc" column divided by the row count.
    """
    # NOTE(review): "acc" is presumably a boolean/0-1 hit indicator per row;
    # confirm against predict_df.
    top1 = predict_df(hat, mat, max_guess=1, distance=distance)
    n_hits = top1.acc.sum()
    return n_hits / len(top1)
910

10-
def predict_df (hat, mat, max_guess=1, distance=False):
11-
coss = spd.cdist(np.array(hat), np.array(mat), 'cosine')
11+
def predict_df (hat, mat, max_guess=1, distance=False, method='cosine'):
12+
if not isinstance(max_guess, int): raise TypeError('"max_guess" must be integer')
13+
coss = distance_matrix(pred=hat, gold=mat, method=method).values
1214
if distance:
1315
pos1 = [np.argmin(coss, axis=1)]
1416
sign = 1
1517
else:
1618
coss = 1 - coss
1719
pos1 = [np.argmax(coss, axis=1)]
1820
sign = -1
19-
assert isinstance(max_guess, int)
21+
2022
if max_guess>1:
2123
pos = [ np.apply_along_axis(lambda x: np.argsort(x)[(sign*i)], 1, coss) for i in range(2,max_guess+1) ]
2224
else:
@@ -36,6 +38,38 @@ def predict_df (hat, mat, max_guess=1, distance=False):
3638
dddd = pd.concat([wrds,prds,hits], axis=1)
3739
return dddd
3840

41+
def distance_matrix(*, pred, gold, method='cosine'):
    """
    Compute the pairwise distances between prediction and gold-standard rows.

    Every row vector of "pred" is compared with every row vector of "gold"
    using scipy.spatial.distance.cdist. The values are distances; when a
    similarity is wanted instead (e.g., correlation / cosine similarity),
    subtract the returned array from 1 (i.e., 1 - dist).

    Parameters
    ----------
    pred : xarray.core.dataarray.DataArray
        A prediction matrix, usually either a C-hat matrix or a S-hat matrix.
    gold : xarray.core.dataarray.DataArray
        A gold-standard matrix, usually either a C matrix or a S matrix.
    method : str
        Distance metric, handed unchanged to cdist. Defaults to 'cosine'.

    Returns
    -------
    dist : xarray.core.dataarray.DataArray
        A 2-d array of the shape m x n with the dimensions "pred" and "gold",
        where m is the number of rows in "pred" and n is the number of rows in
        "gold". The cell in the i-th row and the j-th column holds the distance
        between the i-th row vector of "pred" and the j-th row vector of
        "gold". Each axis is labelled with the coordinate values of the first
        dimension of the corresponding input.
    """
    pairwise = spd.cdist(pred.values, gold.values, method)
    # Carry over the labels of each input's first dimension so the result can
    # be indexed by the same row labels as "pred" and "gold".
    row_labels = {
        'pred': pred[pred.dims[0]].values,
        'gold': gold[gold.dims[0]].values,
    }
    return xr.DataArray(pairwise, dims=('pred', 'gold'), coords=row_labels)
72+
3973
def predict (word, hat, mat, distance=False):
4074
hat = np.tile(hat.loc[word,:], (1,1))
4175
coss = spd.cdist(np.array(hat), np.array(mat), 'cosine')

0 commit comments

Comments
 (0)