• 将 libFM 模型转换为可供 TensorFlow Serving 加载的形式


    fm_model 是 libFM 生成的模型文件(代码中读取的文件名为 libfm_model_file)。

    model.ckpt 是导出的 SavedModel 目录(由 tf.saved_model.simple_save 生成;虽然名字带 .ckpt,但它不是 checkpoint 文件),可直接供 TensorFlow Serving 加载。

    亲测输出正确。

    代码:

     1 import tensorflow as tf
     2 
     3 # libFM model
     4 def load_fm_model(file_name):
     5     state = ''
     6     fid = 0
     7     max_fid = 0
     8     w0 = 0.0
     9     wj = {}
    10     v = {}
    11     k = 0
    12     with open(file_name) as f:
    13         for line in f:
    14             line = line.rstrip()
    15             if 'global bias W0' in line:
    16                 state = 'w0'
    17                 fid = 0
    18                 continue
    19             elif 'unary interactions Wj' in line:
    20                 state = 'wj'
    21                 fid = 0
    22                 continue
    23             elif 'pairwise interactions Vj,f' in line:
    24                 state = 'v'
    25                 fid = 0
    26                 continue
    27 
    28             if state == 'w0':
    29                 fv = float(line)
    30                 w0 = fv
    31             elif state == 'wj':
    32                 fv = float(line)
    33                 if fv != 0:
    34                     wj[fid] = fv
    35                 fid += 1
    36                 max_fid = max(max_fid, fid)
    37             elif state == 'v':
    38                 fv = [float(_v) for _v in line.split(' ')]
    39                 k = len(fv)
    40                 if any([_v!=0 for _v in fv]):
    41                     v[fid] = fv
    42                 fid += 1
    43                 max_fid = max(max_fid, fid)
    44     return w0, wj, v, k, max_fid
    45 
    46 _w0, _wj, _v, _k, _max_fid = load_fm_model('libfm_model_file')
    47 
    48 # max feature_id
    49 n = _max_fid
    50 print 'n', n
    51 
    52 # vector dimension
    53 k = _k
    54 print 'k', k
    55 
    56 # write fm algorithm
    57 w0 = tf.constant(_w0)
    58 w1c = tf.constant([_wj.get(fid, 0) for fid in xrange(n)], shape=[n])
    59 w1 = tf.Variable(w1c)
    60 #print 'w1', w1
    61 
    62 vec = []
    63 for fid in xrange(n):
    64     vec.append(_v.get(fid, [0]*k))
    65 w2c = tf.constant(vec, shape=[n,k])
    66 w2 = tf.Variable(w2c)
    67 print 'w2', w2
    68 
    69 # inputs
    70 x = tf.placeholder(tf.string, [None])
    71 batch = tf.shape(x)[0]
    72 x_s = tf.string_split(x)
    73 inds = tf.stack([tf.cast(x_s.indices[:,0], tf.int64), tf.string_to_number(x_s.values, tf.int64)], axis=1)
    74 x_sparse = tf.sparse.SparseTensor(indices=inds, values=tf.ones([tf.shape(inds)[0]]), dense_shape=[batch,n])
    75 x_ = tf.sparse.to_dense(x_sparse)
    76 
    77 w2_rep = tf.reshape(tf.tile(w2, [batch,1]), [-1,n,k])
    78 print 'w2_rep', w2_rep
    79 
    80 x_rep = tf.reshape(tf.tile(tf.reshape(x_, [batch*n, 1]), [1,k]), [-1,n,k])
    81 print 'x_rep', x_rep
    82 x_rep2 = tf.square(x_rep)
    83 
    84 #print tf.multiply(w2_rep,x_rep)
    85 #print tf.reduce_sum(tf.multiply(w2_rep,x_rep), axis=1)
    86 q = tf.square(tf.reduce_sum(tf.multiply(w2_rep, x_rep), axis=1))
    87 h = tf.reduce_sum(tf.multiply(tf.square(w2_rep), x_rep2), axis=1)
    88 
    89 y = w0 + tf.reduce_sum(tf.multiply(x_, w1), axis=1) +
    90     1.0/2 * tf.reduce_sum(q-h, axis=1)
    91 
    92 saver = tf.train.Saver()
    93 with tf.Session() as sess:
    94     sess.run(tf.global_variables_initializer())
    95     #a = sess.run(y, feed_dict={x_:x_train,y_:y_train,batch:70})
    96     #print a
    97     save_path = "./model.ckpt"
    98     tf.saved_model.simple_save(sess, save_path, inputs={"x": x}, outputs={"y": y})

    参考:

    https://blog.csdn.net/u010159842/article/details/78789355 (开头借鉴此文,但其有不少细节错误)

    https://www.tensorflow.org/guide/saved_model

    http://nowave.it/factorization-machines-with-tensorflow.html

  • 相关阅读:
    Java虚拟机详解(二)------运行时内存结构
    Java虚拟机详解(一)------简介
    分布式任务调度平台XXL-JOB搭建教程
    Kafka 详解(三)------Producer生产者
    服务器监控异常重启服务并发送邮件
    超详细的Linux查找大文件和查找大目录技巧
    linux清理磁盘空间
    Magent实现Memcached集群
    Nginx反爬虫: 禁止某些User Agent抓取网站
    redis集群搭建详细过程
  • 原文地址:https://www.cnblogs.com/yaoyaohust/p/10472780.html
Copyright © 2020-2023  润新知