Skip to content

Commit 7e0c015

Browse files
authored
Module name changes in documentation and adds dlm.ldl.LDL.accuracy (#8)
* Fixes tests * Adds the 'count' argument to dlm.ldl.LDL.gen_cmat. * Adds the dependency on netcdf4 in pyproject.toml * Adds a new argument 'mats' to dlm.ldl.LDL.save_matrices for saving matrices selectively * Fixes the generation of C-hat and S-hat in docs (quickstart.rst) * Fixes pyldl to discriminative_lexicon_model in docs * Adds a docstring to dlm.performance.accuracy * Adds dlm.ldl.LDL.accuracy * Updates .gitignore to exclude notes/
1 parent c7cdc5e commit 7e0c015

File tree

3 files changed

+68
-0
lines changed

3 files changed

+68
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ discriminative_lexicon_model.egg-info/
33
*.swp
44
build/
55
dist/
6+
notes/

discriminative_lexicon_model/ldl.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
from . import mapping as lm
5+
from . import performance as lp
56

67
class LDL:
78
def __init__ (self, words=None, embed_or_df=None, cmat=False, smat=False,
@@ -136,6 +137,34 @@ def load_matrices (self, directory, add=''):
136137
setattr(self, i, mat)
137138
return None
138139

140+
def accuracy (self, method='correlation', print_output=True):
    """
    Reports prediction accuracy for whichever of C-hat and S-hat exist.

    If the instance has a ``chat`` attribute, it is scored against
    ``cmat`` and labelled "Production". If the instance has a ``shat``
    attribute, it is scored against ``smat`` and labelled "Comprehension".
    Scoring is delegated to ``lp.accuracy``.

    Parameters
    ----------
    method : str
        Passed through unchanged to ``lp.accuracy`` as ``method``.
    print_output : bool
        If True, the available scores are printed, one per line in the
        form "<label>: <score>" (Comprehension before Production), and
        None is returned. If False, nothing is printed and the scores are
        returned as a dictionary.

    Returns
    -------
    acc : dict or None
        None when ``print_output`` is True. Otherwise a dictionary with
        the keys 'Comprehension' and/or 'Production', containing only the
        scores that could be computed.

    Raises
    ------
    ValueError
        If no score is available, e.g. the instance has neither ``chat``
        nor ``shat``.

    Notes
    -----
    Only the presence of ``chat``/``shat`` is checked. The gold-standard
    matrices ``cmat``/``smat`` are read without an existence check.
    """
    acc_prod = acc_comp = None
    # Production is evaluated before comprehension, as before; only the
    # reporting order puts comprehension first.
    if hasattr(self, 'chat'):
        acc_prod = lp.accuracy(pred=self.chat, gold=self.cmat, method=method)
    if hasattr(self, 'shat'):
        acc_comp = lp.accuracy(pred=self.shat, gold=self.smat, method=method)

    # Keep only the scores that were actually computed. Insertion order
    # fixes the order of both the printed lines and the returned keys.
    scores = {'Comprehension': acc_comp, 'Production': acc_prod}
    scores = {label: value for label, value in scores.items() if value is not None}
    if not scores:
        raise ValueError('No C-hat or S-hat was found.')

    if not print_output:
        return scores
    print('\n'.join('{}: {}'.format(label, value) for label, value in scores.items()))
    return None
167+
139168
def concat_cues (a):
140169
assert is_consecutive(a)
141170
a = pd.Series(a).str.slice(start=0, stop=1).iloc[:-1].str.cat(sep='') + pd.Series(a).iloc[-1]

discriminative_lexicon_model/performance.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,44 @@
44
import scipy.spatial.distance as spd
55

66
def accuracy (*, pred, gold, method='correlation'):
    """
    Calculates prediction accuracy from a matrix of predictions and a matrix
    of gold-standard vectors. A prediction is considered "correct" when its
    corresponding gold-standard vector is the one most strongly correlated
    with the predicted vector.

    Parameters
    ----------
    pred : xarray.core.dataarray.DataArray
        A matrix of predictions. It is usually a C-hat or S-hat matrix.
    gold : xarray.core.dataarray.DataArray
        A matrix of gold-standard vectors. It is usually a C or S matrix.
    method : str
        Which method to use to calculate distance/similarity. It must be
        "correlation", "cosine" (for cosine similarity), or "euclidean" (for
        euclidean distance).

    Returns
    -------
    n : float
        The accuracy of the predictions, namely the ratio of words that are
        predicted correctly to the total number of words.

    Examples
    --------
    >>> import discriminative_lexicon_model as dlm
    >>> import pandas as pd
    >>> words = ['cat','rat','hat']
    >>> sems = pd.DataFrame({'<animate>':[1,1,0], '<object>':[0,0,1], '<predator>':[1,0,0]}, index=words)
    >>> mdl = dlm.ldl.LDL()
    >>> mdl.gen_cmat(words)
    >>> mdl.gen_smat(sems)
    >>> mdl.gen_gmat()
    >>> mdl.gen_chat()
    >>> print(dlm.performance.accuracy(pred=mdl.chat, gold=mdl.cmat, method='correlation'))
    1.0
    """
    # Keep only the single best candidate (n=1) for every prediction.
    top_hits = predict_df(pred=pred, gold=gold, n=1, method=method)
    n_correct = top_hits.Correct.sum()
    n_total = len(top_hits)
    return n_correct / n_total

0 commit comments

Comments
 (0)