Skip to content

Commit beee173

Browse files
committed
add S
1 parent 6ef55d6 commit beee173

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

src/find_path.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,12 @@ function learn_paths_rpi(
873873

874874
n = size(res)
875875
ngrams_ind = make_ngrams_ind(res, n)
876+
Shat = zeros(Float64, size(S_val))
877+
878+
for i in 1:n[1]
879+
ci = ngrams_ind[i]
880+
Shat[i,:] = sum(F_train[ci, :], dims = 1)
881+
end
876882

877883
tmp, rpi = learn_paths(
878884
data_train,
@@ -885,7 +891,7 @@ function learn_paths_rpi(
885891
i2f,
886892
f2i,
887893
gold_ind = ngrams_ind,
888-
Shat_val = nothing,
894+
Shat_val = Shat,
889895
check_gold_path = true,
890896
max_t = max_t,
891897
max_can = 1,
@@ -914,6 +920,7 @@ function learn_paths_rpi(
914920
else
915921
return res, rpi
916922
end
923+
917924
end
918925

919926

0 commit comments

Comments
 (0)