1
+ from posixpath import join
2
+ import numpy
3
+ from numpy .lib .npyio import save
4
+ from script .data_iterator import DataIterator
5
+ import tensorflow as tf
6
+ import time
7
+ import random
8
+ import sys
9
+ from script .utils import *
10
+ from tensorflow .python .framework import ops
11
+ import os
12
+ import json
13
+
14
+ EMBEDDING_DIM = 18
15
+ HIDDEN_SIZE = 18 * 2
16
+ ATTENTION_SIZE = 18 * 2
17
+ best_auc = 0.0
18
+ best_case_acc = 0.0
19
+ batch_size = 1
20
+ maxlen = 100
21
+
22
+ data_location = '../data'
23
+ test_file = os .path .join (data_location , "local_test_splitByUser" )
24
+ uid_voc = os .path .join (data_location , "uid_voc.pkl" )
25
+ mid_voc = os .path .join (data_location , "mid_voc.pkl" )
26
+ cat_voc = os .path .join (data_location , "cat_voc.pkl" )
27
+
28
+ def prepare_data (input , target , maxlen = None , return_neg = False ):
29
+ # x: a list of sentences
30
+ lengths_x = [len (s [4 ]) for s in input ]
31
+ seqs_mid = [inp [3 ] for inp in input ]
32
+ seqs_cat = [inp [4 ] for inp in input ]
33
+ noclk_seqs_mid = [inp [5 ] for inp in input ]
34
+ noclk_seqs_cat = [inp [6 ] for inp in input ]
35
+
36
+ if maxlen is not None :
37
+ new_seqs_mid = []
38
+ new_seqs_cat = []
39
+ new_noclk_seqs_mid = []
40
+ new_noclk_seqs_cat = []
41
+ new_lengths_x = []
42
+ for l_x , inp in zip (lengths_x , input ):
43
+ if l_x > maxlen :
44
+ new_seqs_mid .append (inp [3 ][l_x - maxlen :])
45
+ new_seqs_cat .append (inp [4 ][l_x - maxlen :])
46
+ new_noclk_seqs_mid .append (inp [5 ][l_x - maxlen :])
47
+ new_noclk_seqs_cat .append (inp [6 ][l_x - maxlen :])
48
+ new_lengths_x .append (maxlen )
49
+ else :
50
+ new_seqs_mid .append (inp [3 ])
51
+ new_seqs_cat .append (inp [4 ])
52
+ new_noclk_seqs_mid .append (inp [5 ])
53
+ new_noclk_seqs_cat .append (inp [6 ])
54
+ new_lengths_x .append (l_x )
55
+ lengths_x = new_lengths_x
56
+ seqs_mid = new_seqs_mid
57
+ seqs_cat = new_seqs_cat
58
+ noclk_seqs_mid = new_noclk_seqs_mid
59
+ noclk_seqs_cat = new_noclk_seqs_cat
60
+
61
+ if len (lengths_x ) < 1 :
62
+ return None , None , None , None
63
+
64
+ n_samples = len (seqs_mid )
65
+ maxlen_x = numpy .max (lengths_x )
66
+ neg_samples = len (noclk_seqs_mid [0 ][0 ])
67
+
68
+ mid_his = numpy .zeros ((n_samples , maxlen_x )).astype ('int64' )
69
+ cat_his = numpy .zeros ((n_samples , maxlen_x )).astype ('int64' )
70
+ noclk_mid_his = numpy .zeros (
71
+ (n_samples , maxlen_x , neg_samples )).astype ('int64' )
72
+ noclk_cat_his = numpy .zeros (
73
+ (n_samples , maxlen_x , neg_samples )).astype ('int64' )
74
+ mid_mask = numpy .zeros ((n_samples , maxlen_x )).astype ('float32' )
75
+ for idx , [s_x , s_y , no_sx , no_sy ] in enumerate (
76
+ zip (seqs_mid , seqs_cat , noclk_seqs_mid , noclk_seqs_cat )):
77
+ mid_mask [idx , :lengths_x [idx ]] = 1.
78
+ mid_his [idx , :lengths_x [idx ]] = s_x
79
+ cat_his [idx , :lengths_x [idx ]] = s_y
80
+ noclk_mid_his [idx , :lengths_x [idx ], :] = no_sx
81
+ noclk_cat_his [idx , :lengths_x [idx ], :] = no_sy
82
+
83
+ uids = numpy .array ([inp [0 ] for inp in input ])
84
+ mids = numpy .array ([inp [1 ] for inp in input ])
85
+ cats = numpy .array ([inp [2 ] for inp in input ])
86
+
87
+ if return_neg :
88
+ return uids , mids , cats , mid_his , cat_his , mid_mask , numpy .array (
89
+ target ), numpy .array (lengths_x ), noclk_mid_his , noclk_cat_his
90
+
91
+ else :
92
+ return uids , mids , cats , mid_his , cat_his , mid_mask , numpy .array (
93
+ target ), numpy .array (lengths_x )
94
+
95
+
96
+ test_data = DataIterator (test_file ,
97
+ uid_voc ,
98
+ mid_voc ,
99
+ cat_voc ,
100
+ batch_size ,
101
+ maxlen ,
102
+ data_location = data_location )
103
+
104
+ f = open ("./test_data.csv" ,"w" )
105
+ counter = 0
106
+
107
+ for src , tgt in test_data :
108
+ uids , mids , cats , mid_his , cat_his , mid_mask , target , sl = prepare_data (src , tgt )
109
+ all_data = [uids , mids , cats , mid_his , cat_his , mid_mask , target , sl ]
110
+ for cur_data in all_data :
111
+ cur_data = numpy .squeeze (cur_data ).reshape (- 1 )
112
+ for col in range (cur_data .shape [0 ]):
113
+ uid = cur_data [col ]
114
+ # print(uid)
115
+ if col == cur_data .shape [0 ]- 1 :
116
+ f .write (str (uid )+ ",k," )
117
+ break
118
+ f .write (str (uid )+ "," )
119
+
120
+ f .write ("\n " );
121
+ if counter >= 1 :
122
+ break
123
+ counter += 1
124
+
125
+ f .close ()
0 commit comments