Skip to content

Commit ad80092

Browse files
authored
Refactors functions in performance.py (#2)
* Refactors functions in performance.py - accuracy now has keyword-only arguments. - predict_df now has keyword-only arguments. - `accuracy (*, pred, gold, method='correlation')` (old: `accuracy (hat, mat, distance=False)`) - `predict_df (*, pred, gold, n=1, method='correlation')` (old: `predict_df (hat, mat, max_guess=1, distance=False, method='cosine')`) - Adds docstrings to `predict_df`. * Increments the version number & cleans pyproject.toml * Fixes a small bug in performance.py; fixes tests
1 parent aa4a39e commit ad80092

14 files changed

+102
-125
lines changed

discriminative_lexicon_model/performance.py

Lines changed: 63 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,40 +3,72 @@
33
import xarray as xr
44
import scipy.spatial.distance as spd
55

6-
def accuracy (hat, mat, distance=False):
7-
pred = predict_df(hat, mat, max_guess=1, distance=distance)
8-
acc = pred.acc.sum() / len(pred)
6+
def accuracy (*, pred, gold, method='correlation'):
7+
pred = predict_df(pred=pred, gold=gold, n=1, method=method)
8+
acc = pred.Correct.sum() / len(pred)
99
return acc
1010

11-
def predict_df (hat, mat, max_guess=1, distance=False, method='cosine'):
12-
if not isinstance(max_guess, int): raise TypeError('"max_guess" must be integer')
13-
coss = distance_matrix(pred=hat, gold=mat, method=method).values
14-
if distance:
15-
pos1 = [np.argmin(coss, axis=1)]
16-
sign = 1
17-
else:
18-
coss = 1 - coss
19-
pos1 = [np.argmax(coss, axis=1)]
20-
sign = -1
11+
def predict_df (*, pred, gold, n=1, method='correlation'):
12+
"""
13+
Constructs a dataframe of predictions.
2114
22-
if max_guess>1:
23-
pos = [ np.apply_along_axis(lambda x: np.argsort(x)[(sign*i)], 1, coss) for i in range(2,max_guess+1) ]
24-
else:
25-
pos = []
26-
pos = pos1 + pos
27-
prds = [ [ mat.word.values[j] for j in i ] for i in pos ]
28-
hits = [ [ j==k for j,k in zip(i,hat.word.values) ] for i in prds ]
29-
if len(prds)==1:
30-
prds = [ pd.DataFrame({'pred':j}) for j in prds ]
31-
hits = [ pd.DataFrame({'acc':j}) for j in hits ]
32-
else:
33-
prds = [ pd.DataFrame({'pred{:d}'.format(i+1):j}) for i,j in enumerate(prds) ]
34-
hits = [ pd.DataFrame({'acc{:d}'.format(i+1):j}) for i,j in enumerate(hits) ]
35-
prds = pd.concat(prds, axis=1)
36-
hits = pd.concat(hits, axis=1)
37-
wrds = pd.DataFrame({'Word':hat.word.values})
38-
dddd = pd.concat([wrds,prds,hits], axis=1)
39-
return dddd
15+
Parameters
16+
----------
17+
pred : xarray.core.dataarray.DataArray
18+
A matrix of predictions. It is usually a C-hat or S-hat matrix.
19+
gold : xarray.core.dataarray.DataArray
20+
A matrix of gold-standard vectors. It is usually a C or S matrix.
21+
n : int or None
22+
The number of predictions to make for each word. When n=1, the first prediction for each word will be produced. When n=2, the first and second predictions for each word will be included in the output dataframe. When n=None, as many predictions as possible will be produced.
23+
method : str
24+
Which method to use to calculate distance/similarity. It must be one of "correlation", "cosine" (for cosine similarity), or "euclidean" (for Euclidean distance).
25+
26+
Returns
27+
-------
28+
df : pandas.core.frame.DataFrame
29+
A dataframe of a model's predictions.
30+
31+
Examples
32+
--------
33+
>>> import discriminative_lexicon_model as dlm
34+
>>> import pandas as pd
35+
>>> words = ['cat','rat','hat']
36+
>>> sems = pd.DataFrame({'<animate>':[1,1,0], '<object>':[0,0,1], '<predator>':[1,0,0]}, index=words)
37+
>>> mdl = dlm.ldl.LDL()
38+
>>> mdl.gen_cmat(words)
39+
>>> mdl.gen_smat(sems)
40+
>>> mdl.gen_gmat()
41+
>>> mdl.gen_chat()
42+
>>> dlm.performance.predict_df(pred=mdl.chat, gold=mdl.cmat, n=2, method='correlation')
43+
Word Pred1 Pred2 Correct1 Correct2
44+
0 cat cat hat True False
45+
1 rat rat hat True False
46+
2 hat hat cat True False
47+
"""
48+
if not (method in ['correlation', 'cosine', 'euclidean']):
49+
raise ValueError('"method" must be "correlation", "cosine", or "euclidean".')
50+
if not (n is None):
51+
if not isinstance(n, int):
52+
raise TypeError('"n" must be integer or None.')
53+
if not (n>0):
54+
raise ValueError('"n" must be a positive integer.')
55+
n = pred.shape[0] if n is None else n
56+
57+
dist = distance_matrix(pred=pred, gold=gold, method=method).values
58+
dist = dist if method=='euclidean' else 1-dist
59+
inds = dist.argsort(axis=1) if method=='euclidean' else (-dist).argsort(axis=1)
60+
inds = inds[:,:n]
61+
62+
prds = np.apply_along_axis(lambda x: gold.word.values[x], 1, inds)
63+
hits = np.array([ prds[i,:]==j for i,j in zip(range(prds.shape[0]), gold.word.values) ])
64+
65+
clms = ['Pred'] if prds.shape[1]==1 else [ 'Pred{:d}'.format(i) for i in range(1, prds.shape[1]+1) ]
66+
prds = pd.DataFrame(prds, columns=clms)
67+
clms = ['Correct'] if hits.shape[1]==1 else [ 'Correct{:d}'.format(i) for i in range(1, hits.shape[1]+1) ]
68+
hits = pd.DataFrame(hits, columns=clms)
69+
wrds = pd.DataFrame({'Word':gold.word.values})
70+
df = pd.concat([wrds, prds, hits], axis=1)
71+
return df
4072

4173
def distance_matrix (*, pred, gold, method='cosine'):
4274
"""
@@ -70,15 +102,3 @@ def distance_matrix (*, pred, gold, method='cosine'):
70102
dist = xr.DataArray(dist, dims=('pred','gold'), coords=new_coords)
71103
return dist
72104

73-
def predict (word, hat, mat, distance=False):
74-
hat = np.tile(hat.loc[word,:], (1,1))
75-
coss = spd.cdist(np.array(hat), np.array(mat), 'cosine')
76-
if distance:
77-
sign = 1
78-
else:
79-
coss = 1 - coss
80-
sign = -1
81-
coss = coss[0,:]
82-
pred = mat.word.values[np.argsort(sign*coss)]
83-
return pd.Series(pred)
84-

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "discriminative_lexicon_model"
3-
version = "1.4.3"
3+
version = "2.0.0"
44
description = "Python-implementation of Discriminative Lexicon Model / Linear Discriminative Learning"
55

66
license = "MIT"
@@ -40,5 +40,5 @@ sphinx = ">=7.3"
4040
sphinx_rtd_theme = ">=2.0"
4141

4242
[build-system]
43-
requires = ["poetry-core", "setuptools", "Cython", "numpy"]
43+
requires = ["poetry-core"]
4444
build-backend = "poetry.core.masonry.api"

tests/resources/predict_df_00.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
Word pred acc
1+
Word Pred Correct
22
walk0 walk0 True
33
walk1 walk0 False
44
walks walks True

tests/resources/predict_df_01.csv

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
Word pred acc
2-
walk0 walk0 True
3-
walk1 walk0 False
4-
walks walks True
5-
walked0 walked0 True
6-
walked1 walked0 False
7-
walked2 walked0 False
1+
Word Pred1 Pred2 Correct1 Correct2
2+
walk0 walk0 walk1 True False
3+
walk1 walk0 walk1 False True
4+
walks walks walk0 True False
5+
walked0 walked0 walked1 True False
6+
walked1 walked0 walked1 False True
7+
walked2 walked0 walked1 False False

tests/resources/predict_df_02.csv

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
Word pred1 pred2 acc1 acc2
2-
walk0 walk0 walks True False
3-
walk1 walk0 walks False False
4-
walks walks walk1 True False
5-
walked0 walked0 walked2 True False
6-
walked1 walked0 walked2 False False
7-
walked2 walked0 walked2 False True
1+
Word Pred Correct
2+
walk0 walk1 False
3+
walk1 walk1 True
4+
walks walks True
5+
walked0 walked2 False
6+
walked1 walked2 False
7+
walked2 walked2 True

tests/resources/predict_df_03.csv

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
Word pred1 pred2 acc1 acc2
2-
walk0 walk0 walk0 True True
3-
walk1 walk0 walk0 False False
4-
walks walks walk1 True False
5-
walked0 walked0 walked1 True False
6-
walked1 walked0 walked1 False True
7-
walked2 walked0 walked1 False False
1+
Word Pred1 Pred2 Correct1 Correct2
2+
walk0 walk1 walk0 False True
3+
walk1 walk1 walk0 True False
4+
walks walks walked2 True False
5+
walked0 walked2 walked0 False True
6+
walked1 walked2 walked0 False False
7+
walked2 walked2 walked0 True False

tests/resources/predict_df_04.csv

Lines changed: 0 additions & 7 deletions
This file was deleted.

tests/resources/predict_df_05.csv

Lines changed: 0 additions & 7 deletions
This file was deleted.

tests/resources/predict_df_06.csv

Lines changed: 0 additions & 7 deletions
This file was deleted.

tests/resources/predict_df_07.csv

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)