Commit 576652e

reorganize the file structure
1 parent 1370df7 commit 576652e

File tree

8 files changed: +1704 −0 lines changed

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
from posixpath import join
import numpy
from numpy.lib.npyio import save
from script.data_iterator import DataIterator
import tensorflow as tf
import time
import random
import sys
from script.utils import *
from tensorflow.python.framework import ops
import os
import json

EMBEDDING_DIM = 18
HIDDEN_SIZE = 18 * 2
ATTENTION_SIZE = 18 * 2
best_auc = 0.0
best_case_acc = 0.0
batch_size = 1
maxlen = 100

data_location = '../data'
test_file = os.path.join(data_location, "local_test_splitByUser")
uid_voc = os.path.join(data_location, "uid_voc.pkl")
mid_voc = os.path.join(data_location, "mid_voc.pkl")
cat_voc = os.path.join(data_location, "cat_voc.pkl")


def prepare_data(input, target, maxlen=None, return_neg=False):
    # input: a list of samples; each sample is
    # [uid, mid, cat, mid_history, cat_history, noclk_mid_history, noclk_cat_history]
    lengths_x = [len(s[4]) for s in input]
    seqs_mid = [inp[3] for inp in input]
    seqs_cat = [inp[4] for inp in input]
    noclk_seqs_mid = [inp[5] for inp in input]
    noclk_seqs_cat = [inp[6] for inp in input]

    # Truncate histories longer than maxlen, keeping the most recent items.
    if maxlen is not None:
        new_seqs_mid = []
        new_seqs_cat = []
        new_noclk_seqs_mid = []
        new_noclk_seqs_cat = []
        new_lengths_x = []
        for l_x, inp in zip(lengths_x, input):
            if l_x > maxlen:
                new_seqs_mid.append(inp[3][l_x - maxlen:])
                new_seqs_cat.append(inp[4][l_x - maxlen:])
                new_noclk_seqs_mid.append(inp[5][l_x - maxlen:])
                new_noclk_seqs_cat.append(inp[6][l_x - maxlen:])
                new_lengths_x.append(maxlen)
            else:
                new_seqs_mid.append(inp[3])
                new_seqs_cat.append(inp[4])
                new_noclk_seqs_mid.append(inp[5])
                new_noclk_seqs_cat.append(inp[6])
                new_lengths_x.append(l_x)
        lengths_x = new_lengths_x
        seqs_mid = new_seqs_mid
        seqs_cat = new_seqs_cat
        noclk_seqs_mid = new_noclk_seqs_mid
        noclk_seqs_cat = new_noclk_seqs_cat

    if len(lengths_x) < 1:
        return None, None, None, None

    # Pad every history to the longest one in the batch and build the mask.
    n_samples = len(seqs_mid)
    maxlen_x = numpy.max(lengths_x)
    neg_samples = len(noclk_seqs_mid[0][0])

    mid_his = numpy.zeros((n_samples, maxlen_x)).astype('int64')
    cat_his = numpy.zeros((n_samples, maxlen_x)).astype('int64')
    noclk_mid_his = numpy.zeros(
        (n_samples, maxlen_x, neg_samples)).astype('int64')
    noclk_cat_his = numpy.zeros(
        (n_samples, maxlen_x, neg_samples)).astype('int64')
    mid_mask = numpy.zeros((n_samples, maxlen_x)).astype('float32')
    for idx, [s_x, s_y, no_sx, no_sy] in enumerate(
            zip(seqs_mid, seqs_cat, noclk_seqs_mid, noclk_seqs_cat)):
        mid_mask[idx, :lengths_x[idx]] = 1.
        mid_his[idx, :lengths_x[idx]] = s_x
        cat_his[idx, :lengths_x[idx]] = s_y
        noclk_mid_his[idx, :lengths_x[idx], :] = no_sx
        noclk_cat_his[idx, :lengths_x[idx], :] = no_sy

    uids = numpy.array([inp[0] for inp in input])
    mids = numpy.array([inp[1] for inp in input])
    cats = numpy.array([inp[2] for inp in input])

    if return_neg:
        return uids, mids, cats, mid_his, cat_his, mid_mask, numpy.array(
            target), numpy.array(lengths_x), noclk_mid_his, noclk_cat_his
    else:
        return uids, mids, cats, mid_his, cat_his, mid_mask, numpy.array(
            target), numpy.array(lengths_x)


test_data = DataIterator(test_file,
                         uid_voc,
                         mid_voc,
                         cat_voc,
                         batch_size,
                         maxlen,
                         data_location=data_location)

f = open("./test_data.csv", "w")
counter = 0

# Dump the first two batches: flatten each tensor and write it as a
# comma-separated run terminated by a literal "k" field.
for src, tgt in test_data:
    uids, mids, cats, mid_his, cat_his, mid_mask, target, sl = prepare_data(src, tgt)
    all_data = [uids, mids, cats, mid_his, cat_his, mid_mask, target, sl]
    for cur_data in all_data:
        cur_data = numpy.squeeze(cur_data).reshape(-1)
        for col in range(cur_data.shape[0]):
            uid = cur_data[col]
            if col == cur_data.shape[0] - 1:
                f.write(str(uid) + ",k,")
                break
            f.write(str(uid) + ",")

    f.write("\n")
    if counter >= 1:
        break
    counter += 1

f.close()
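
Note: the CSV layout above is positional. Each line holds one sample's eight tensors (uids, mids, cats, mid_his, cat_his, mid_mask, target, sl), each flattened and written as comma-separated values with a literal "k" field after a tensor's last value. Below is a minimal reader sketch under that assumption; the helper name read_test_rows is hypothetical and not part of this commit:

def read_test_rows(path="./test_data.csv"):
    # Split each line on the "k" marker fields emitted by the dump loop above.
    rows = []
    with open(path) as fin:
        for line in fin:
            fields = line.strip().strip(",").split(",")
            tensors, cur = [], []
            for tok in fields:
                if tok == "k":  # end-of-tensor marker
                    tensors.append(cur)
                    cur = []
                else:
                    cur.append(float(tok))
            rows.append(tensors)  # eight flattened tensors per sample
    return rows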
Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
from posixpath import join
import numpy
from numpy.lib.npyio import save
from script.data_iterator import DataIterator
import tensorflow as tf
from script.model import *
import time
import random
import sys
from script.utils import *
from tensorflow.python.framework import ops
from tensorflow.python.client import timeline
import argparse
import os
import json
import pickle as pkl

EMBEDDING_DIM = 18
HIDDEN_SIZE = 18 * 2
ATTENTION_SIZE = 18 * 2
best_auc = 0.0
best_case_acc = 0.0
batch_size = 128


def unicode_to_utf8(d):
    return dict((key.encode("UTF-8"), value) for (key, value) in d.items())


def load_dict(filename):
    # Vocabulary files may be JSON or pickle; fall back to pickle on failure.
    try:
        with open(filename, 'rb') as f:
            return unicode_to_utf8(json.load(f))
    except Exception:
        with open(filename, 'rb') as f:
            return pkl.load(f)


def main(n_uid, n_mid, n_cat):
    with tf.Session() as sess1:
        # Rebuild the DIEN graph with the same dimensions used in training.
        model = Model_DIN_V2_Gru_Vec_attGru_Neg(n_uid, n_mid, n_cat,
                                                EMBEDDING_DIM, HIDDEN_SIZE,
                                                ATTENTION_SIZE)

        # Initialize saver
        folder_dir = args.checkpoint
        saver = tf.train.Saver()

        sess1.run(tf.global_variables_initializer())
        sess1.run(tf.local_variables_initializer())
        # Restore from checkpoint
        saver.restore(sess1, tf.train.latest_checkpoint(folder_dir))

        # Export to a timestamped SavedModel directory
        dir = "./savedmodels"
        os.makedirs(dir, exist_ok=True)
        cc_time = int(time.time())
        saved_path = os.path.join(dir, str(cc_time))
        os.mkdir(saved_path)

        tf.saved_model.simple_save(
            sess1,
            saved_path,
            inputs={
                "Inputs/mid_his_batch_ph:0": model.mid_his_batch_ph,
                "Inputs/cat_his_batch_ph:0": model.cat_his_batch_ph,
                "Inputs/uid_batch_ph:0": model.uid_batch_ph,
                "Inputs/mid_batch_ph:0": model.mid_batch_ph,
                "Inputs/cat_batch_ph:0": model.cat_batch_ph,
                "Inputs/mask:0": model.mask,
                "Inputs/seq_len_ph:0": model.seq_len_ph,
                "Inputs/target_ph:0": model.target_ph,
            },
            outputs={"top_full_connect/add_2:0": model.y_hat})


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('--checkpoint',
                        help='ckpt path',
                        required=False,
                        default='../data')
    parser.add_argument('--bf16',
                        help='enable DeepRec BF16 in deep model. Default FP32',
                        action='store_true')
    parser.add_argument('--data_location',
                        help='Full path of train data',
                        required=False,
                        default='./data')
    args = parser.parse_args()

    uid_voc = os.path.join(args.data_location, "uid_voc.pkl")
    mid_voc = os.path.join(args.data_location, "mid_voc.pkl")
    cat_voc = os.path.join(args.data_location, "cat_voc.pkl")

    # Model sizes come from the vocabulary sizes.
    uid_d = load_dict(uid_voc)
    mid_d = load_dict(mid_voc)
    cat_d = load_dict(cat_voc)

    main(len(uid_d), len(mid_d), len(cat_d))
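
To verify the export, here is a minimal sketch (not part of this commit) of loading the SavedModel back in a fresh TF1 session. simple_save writes the serve tag, and the tensor names match the inputs/outputs dicts above; the saved_path timestamp below is a placeholder, not a real directory from this commit:

import tensorflow as tf

# Placeholder: point this at the timestamped directory the script created.
saved_path = "./savedmodels/1650000000"

with tf.Session(graph=tf.Graph()) as sess:
    # simple_save exports under the SERVING ("serve") tag.
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], saved_path)
    y_hat = sess.graph.get_tensor_by_name("top_full_connect/add_2:0")
    uid_ph = sess.graph.get_tensor_by_name("Inputs/uid_batch_ph:0")
    # ...look up the remaining input placeholders the same way, then:
    # preds = sess.run(y_hat, feed_dict={uid_ph: uids, ...})

The same check can be done from the shell with "saved_model_cli show --dir <saved_path> --all", which lists the signature's inputs and outputs.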

0 commit comments
