【TensorFlow小记】线性回归案例

本文是一个TensorFlow入门小案例,通过梯度下降来解决线性回归问题。

一、环境

  • 开发环境
    • Windows
  • Python版本
    • Python 3.5.4
  • pip包
    • tensorflow==1.5.0
    • numpy==1.16.3
    • matplotlib==3.0.3

二、案例介绍

  线性回归一般用于预测,比如:股票涨跌。
  梯度下降是机器学习中最核心的优化算法。
  目的:
  用TensorFlow和梯度下降来解决线性回归问题,找到最佳拟合(穿过所有点,平均到所有点距离最短)的一条值线。之后,给定x值,这条直线可以预测y值。

三、完整代码

  该案例的完整代码及相关注释如下LR_using_GD.py:

# -*- coding: utf-8 -*-
"""
用梯度下降的优化方法来快速解决线性回归问题
需求:
对于直线y=Wx+b,随机生成100个点,围绕在y=Wx+b直线周围;
建立回归模型,学习并训练出W和b,能够更好地拟合出这些数据点
"""

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# 构建数据
points_num = 100
vectors = []
# 用numpy的正太随机分布函数生成100个点
# 这些点的(x,y)坐标值对应线性方程 y = 0.1 * x + 0.2
# 权重(Weight)0.1,偏差(Bias)0.2
for i in range(points_num):
    x1 = np.random.normal(0.0, 0.66)  # # 0为均值,0.66为标准差
    y1 = 0.1 * x1 + 0.2 + np.random.normal(0.0, 0.04)
    vectors.append([x1, y1])

# 生成一些样本
x_data = [v[0] for v in vectors]  # 真实的点的x坐标
y_data = [v[1] for v in vectors]  # 真实的点的y坐标

# 图像1:展示100个随机数据点
plt.plot(x_data, y_data, 'r*', label='Original data')  # 红色星形的点
plt.title('Linear Regression using Gradient Descent')
plt.legend()
plt.show()

# 构建线性回归模型
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')  # 初始化权重(Weight):生成1维的W矩阵,取值是[-1,1]之间的随机数
b = tf.Variable(tf.zeros([1]), name='b')  # 初始化偏差(Bias):生成1维的b矩阵,初始值是0
y = W * x_data + b  # 模型经过计算得出预估值y

# 定义loss function(损失函数)或cost function(代价函数)
# 以预估值y和实际值y_data之间的均方误差作为损失
# 即对Tensor的所有维度计算 ((y - y_data) ^ 2) 之和 / N ,此处N为100
loss = tf.reduce_mean(tf.square(y - y_data), name='loss')

# 用梯度下降的优化器来优化我们的loss function
optimizer = tf.train.GradientDescentOptimizer(0.5)  # 设置学习率 0.5
# 训练的过程就是最小化这个误差值
train = optimizer.minimize(loss, name='train')

# 创建会话
sess = tf.Session()

# 初始化数据流图中的所有变量
init = tf.global_variables_initializer()
sess.run(init)

# 打印初始化的W和b是多少
print('Initial: Loss=%f, [Weight=%f Bias=%f]' % (sess.run(loss), sess.run(W), sess.run(b)))

# 训练 20 步
for step in range(20):
    # 优化每一步
    sess.run(train)
    # 打印出每一步的损失,权重和偏差
    print('Step=%d, Loss=%f, [Weight=%f Bias=%f]' % (step, sess.run(loss), sess.run(W), sess.run(b)))

# 图像2:绘制所有的点并且绘制出最佳拟合的直线
plt.plot(x_data, y_data, 'r*', label='Original data')  # 红色星形的点
plt.title('Linear Regression using Gradient Descent')
plt.plot(x_data, sess.run(W) * x_data + sess.run(b), label='Fitted line')  # 拟合的线
plt.legend()
plt.xlabel('x')
plt.ylabel('y')
plt.show()

# 关闭会话
sess.close()

四、运行结果

  这里我使用的IDE是PyCharm,直接贴上运行结果图:
tensorflow_lr_using_gd.png


  目录