
Commit 77678f9

kuchenrollederNarr
authored and committed
Fixes duplicate outcomes when remove_duplicates=True (#148)

* enforces unique outcomes in preprocess.py; bumps version number
* adds test for remove_duplicates in test_preprocess.py
1 parent 9b47606 commit 77678f9
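
For context, create_event_file writes gzipped event files with one "<cues>\t<outcomes>" line per event, cues and outcomes each joined by "_". Before this commit, remove_duplicates=True deduplicated only the cues, so repeated outcomes slipped through. A minimal sketch of the symptom and of the fix's set-based treatment (the data here is made up, not taken from the test resources):

    # Hypothetical outcomes field of one event line.
    outcomes = "planet_planet_moon"

    # pre-0.5.1: with remove_duplicates=True only the cues were made unique,
    # so this string was written unchanged, keeping the repeated "planet".
    # 0.5.1: outcomes get the same treatment as cues before writing:
    unique_outcomes = "_".join(set(outcomes.split("_")))
    print(unique_outcomes)   # "planet_moon" in some order; set() does not keep order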

File tree: pyndl/__init__.py · pyndl/preprocess.py · setup.cfg · 3 binary files · tests/test_preprocess.py

7 files changed, +66 -20 lines changed

pyndl/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 __author__ = ('Konstantin Sering, Marc Weitz, '
               'David-Elias Künstle, Lennard Schneider')
 __author_email__ = '[email protected]'
-__version__ = '0.5.0'
+__version__ = '0.5.1'
 __license__ = 'MIT'
 __description__ = ('Naive discriminative learning implements learning and '
                    'classification models based on the Rescorla-Wagner '

pyndl/preprocess.py

Lines changed: 17 additions & 18 deletions
@@ -100,9 +100,9 @@ def ngrams_to_word(occurrences, n_chars, outfile, remove_duplicates=True):
         if not ngrams or not occurrence:
             continue
         if remove_duplicates:
-            outfile.write("{}\t{}\n".format("_".join(set(ngrams)), occurrence))
-        else:
-            outfile.write("{}\t{}\n".format("_".join(ngrams), occurrence))
+            ngrams = set(ngrams)
+            occurrence = "_".join(set(occurrence.split("_")))
+        outfile.write("{}\t{}\n".format("_".join(ngrams), occurrence))
 
 
 def process_occurrences(occurrences, outfile, *,
@@ -132,9 +132,9 @@ def process_occurrences(occurrences, outfile, *,
             if not cues:
                 continue
             if remove_duplicates:
-                outfile.write("{}\t{}\n".format("_".join(set(cues.split("_"))), outcomes))
-            else:
-                outfile.write("{}\t{}\n".format(cues, outcomes))
+                cues = "_".join(set(cues.split("_")))
+                outcomes = "_".join(set(outcomes.split("_")))
+            outfile.write("{}\t{}\n".format(cues, outcomes))
     else:
         raise NotImplementedError('cue_structure=%s is not implemented yet.' % cue_structure)
 
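Both hunks above follow the same pattern: instead of choosing between two outfile.write calls, the cues and the outcomes are each made unique first (when remove_duplicates is set) and a single write is used. A standalone sketch with made-up strings, mirroring the new word_to_word branch:

    # Made-up cue/outcome strings for illustration only.
    cues = "the_big_the"
    outcomes = "dog_dog"
    remove_duplicates = True

    if remove_duplicates:
        cues = "_".join(set(cues.split("_")))
        outcomes = "_".join(set(outcomes.split("_")))
    print("{}\t{}".format(cues, outcomes))
    # Caveat: set() iteration order is arbitrary, which is why the new
    # test_remove_duplicates compares sorted lists and sets rather than raw strings.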

@@ -245,19 +245,16 @@ def gen_occurrences(words):
     """
     if event_structure == 'consecutive_words':
         occurrences = list()
-        cur_words = list()
-        ii = 0
-        while True:
-            if ii < len(words):
-                cur_words.append(words[ii])
-            if ii >= len(words) or ii >= number_of_words:
-                # remove the first word
-                cur_words = cur_words[1:]
+        # can't have more consecutive words than total words
+        length = min(number_of_words, len(words))
+        # slide window over list of words
+        for ii in range(1 - length, len(words)):
+            # no consecutive words before first word
+            start = max(ii, 0)
+            # no consecutive words after last word
+            end = min(ii + length, len(words))
             # append (cues, outcomes) with empty outcomes
-            occurrences.append(("_".join(cur_words), ''))
-            ii += 1
-            if not cur_words:
-                break
+            occurrences.append(("_".join(words[start:end]), ""))
         return occurrences
     # for words = (A, B, C, D); before = 2, after = 1
     # make: (B, A), (A_C, B), (A_B_D, C), (B_C, D)
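
The rewritten consecutive_words branch replaces the while-loop bookkeeping with a sliding window that ramps up at the start of the word list and down at the end. A standalone sketch of the same arithmetic, using hypothetical inputs (words and number_of_words are supplied by the surrounding create_event_file machinery):

    # Hypothetical inputs for illustration only.
    words = ["a", "b", "c", "d"]
    number_of_words = 3

    occurrences = []
    # can't have more consecutive words than total words
    length = min(number_of_words, len(words))
    # slide the window over the word list
    for ii in range(1 - length, len(words)):
        start = max(ii, 0)                    # no words before the first word
        end = min(ii + length, len(words))    # no words after the last word
        occurrences.append(("_".join(words[start:end]), ""))

    print(occurrences)
    # [('a', ''), ('a_b', ''), ('a_b_c', ''), ('b_c_d', ''), ('c_d', ''), ('d', '')]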
@@ -274,6 +271,8 @@ def gen_occurrences(words):
     elif event_structure == 'line':
         # (cues, outcomes) with empty outcomes
         return [('_'.join(words), ''), ]
+    else:
+        raise ValueError('gen_occurrences should be one of {"consecutive_words", "word_to_word", "line"}')
 
 def process_line(line):
     """processes one line of text."""

setup.cfg

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ max-line-length = 120
 
 [pylint]
 max-line-length = 120
-good-names = nn, ii, _
+good-names = nn, ii, _, jj
 extension-pkg-whitelist=numpy,pyndl.ndl_parallel
 ignore=pyndl/ndl_parallel
 disable=E1101
Binary file not shown (-502 Bytes).
Binary file not shown (-461 Bytes).
Binary file not shown (-2.08 KB).

tests/test_preprocess.py

Lines changed: 47 additions & 0 deletions
@@ -109,6 +109,53 @@ def test_bigrams_to_word():
     os.remove(event_file)
 
 
+def test_remove_duplicates():
+    event_file_noduplicates = os.path.join(TEST_ROOT, "temp/event_file_bigrams_to_word_noduplicates.tab.gz")
+    event_file_duplicates = os.path.join(TEST_ROOT, "temp/event_file_bigrams_to_word_duplicates.tab.gz")
+    create_event_file(RESOURCE_FILE, event_file_duplicates,
+                      context_structure="document",
+                      event_structure="consecutive_words",
+                      event_options=(3, ),
+                      cue_structure="bigrams_to_word",
+                      remove_duplicates=False)
+    create_event_file(RESOURCE_FILE, event_file_noduplicates,
+                      context_structure="document",
+                      event_structure="consecutive_words",
+                      event_options=(3, ),
+                      cue_structure="bigrams_to_word",
+                      remove_duplicates=True)
+
+    with gzip.open(event_file_noduplicates, "rt") as new_file:
+        lines_new = new_file.readlines()
+    with gzip.open(event_file_duplicates, "rt") as reference:
+        lines_reference = reference.readlines()
+    assert len(lines_new) == len(lines_reference)
+    n_cues_unequal = 0
+    n_outcomes_unequal = 0
+    for ii, line in enumerate(lines_new):
+        cues, outcomes = line.strip().split('\t')
+        cues = sorted(cues.split('_'))
+        outcomes = sorted(outcomes.split('_'))
+        ref_cues, ref_outcomes = lines_reference[ii].strip().split('\t')
+        ref_cues = sorted(ref_cues.split('_'))
+        ref_outcomes = sorted(ref_outcomes.split('_'))
+        if len(cues) != len(ref_cues):
+            n_cues_unequal += 1
+        if len(outcomes) != len(ref_outcomes):
+            n_outcomes_unequal += 1
+        # there should be no duplicates in (noduplicates)
+        assert len(cues) == len(set(cues))
+        assert len(outcomes) == len(set(outcomes))
+        # after making each list unique it should be the same
+        assert set(cues) == set(ref_cues)
+        assert set(outcomes) == set(ref_outcomes)
+    assert n_cues_unequal == 1098
+    assert n_outcomes_unequal == 66
+
+    os.remove(event_file_noduplicates)
+    os.remove(event_file_duplicates)
+
+
 def test_word_to_word():
     event_file = os.path.join(TEST_ROOT, "temp/event_file_word_to_word.tab.gz")
     reference_file = os.path.join(TEST_ROOT, "reference/event_file_word_to_word.tab.gz")
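
Assuming pytest is the project's test runner (the tests/ layout and plain assert style suggest it is), the new test can be run in isolation, for example:

    # Roughly equivalent to:  python -m pytest tests/test_preprocess.py::test_remove_duplicates
    import pytest

    pytest.main(["tests/test_preprocess.py::test_remove_duplicates"])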
