24

使用TensorFlow2预测国内疫情结束时间

 4 years ago
source link: https://flashgene.com/archives/108633.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

国内的新冠肺炎疫情从发现至今已经持续3个多月了,这场起源于吃野味的灾难给大家的生活造成了诸多方面的影响。

有的同学是收入上的,有的同学是感情上的,有的同学是心理上的,还有的同学则是体重上的。

ryIveqV.jpg!web

那幺国内的新冠肺炎疫情何时结束呢?什幺时候我们才可以重获自由呢?

本篇文章将利用TensorFlow2.0建立时间序列RNN模型,对国内的新冠肺炎疫情结束时间进行预测。

7JbYFjn.png!web

一,准备数据

本文的数据集取自tushare,获取该数据集的方法参考了以下文章。

《https://zhuanlan.zhihu.com/p/109556102》

公众号后台回复关键字:疫情预测,获取本文数据集和对应源码。

import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import tensorflow as tf 
from tensorflow.keras import models,layers,losses,metrics,callbacks
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
df = pd.read_csv("./data/covid-19.csv",sep = "\t")
df.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)

yeMjMfn.png!web

dfdata = df.set_index("date")
dfdiff = dfdata.diff(periods=1).dropna()
dfdiff = dfdiff.reset_index("date")
dfdiff.plot(x = "date",y = ["confirmed_num","cured_num","dead_num"],figsize=(10,6))
plt.xticks(rotation=60)
dfdiff = dfdiff.drop("date",axis = 1).astype("float32")

zuQFbqy.png!web

#用某日前8天窗口数据作为输入预测该日数据
WINDOW_SIZE = 8
def batch_dataset(dataset):
    dataset_batched = dataset.batch(WINDOW_SIZE,drop_remainder=True)
    return dataset_batched
ds_data = tf.data.Dataset.from_tensor_slices(tf.constant(dfdiff.values,dtype = tf.float32)) \
   .window(WINDOW_SIZE,shift=1).flat_map(batch_dataset)
ds_label = tf.data.Dataset.from_tensor_slices(
    tf.constant(dfdiff.values[WINDOW_SIZE:],dtype = tf.float32))
#数据较小,可以将全部训练数据放入到一个batch中,提升性能
ds_train = tf.data.Dataset.zip((ds_data,ds_label)).batch(38).cache()

二,定义模型

Keras接口有以下3种方式构建模型:使用Sequential按层顺序构建模型,使用函数式API构建任意结构模型,继承Model基类构建自定义模型。

此处选择使用函数式API构建任意结构模型。

#考虑到新增确诊,新增治愈,新增死亡人数数据不可能小于0,设计如下结构
class Block(layers.Layer):
    def __init__(self, **kwargs):
        super(Block, self).__init__(**kwargs)
    def call(self, x_input,x):
        x_out = tf.maximum((1+x)*x_input[:,-1,:],0.0)
        return x_out
    def get_config(self):  
        config = super(Block, self).get_config()
        return config
tf.keras.backend.clear_session()
x_input = layers.Input(shape = (None,3),dtype = tf.float32)
x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x_input)
x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x)
x = layers.LSTM(3,return_sequences = True,input_shape=(None,3))(x)
x = layers.LSTM(3,input_shape=(None,3))(x)
x = layers.Dense(3)(x)
#考虑到新增确诊,新增治愈,新增死亡人数数据不可能小于0,设计如下结构
#x = tf.maximum((1+x)*x_input[:,-1,:],0.0)
x = Block()(x_input,x)
model = models.Model(inputs = [x_input],outputs = [x])
model.summary()

nyIJ3m7.png!web

三,训练模型

训练模型通常有3种方法,内置fit方法,内置train_on_batch方法,以及自定义训练循环。此处我们选择最常用也最简单的内置fit方法。

注:循环神经网络调试较为困难,需要设置不同的学习率反复尝试,以让Loss下降收敛。

#自定义损失函数,考虑平方差和预测目标平方的比值
class MSPE(losses.Loss):
    def call(self,y_true,y_pred):
        err_percent = (y_true - y_pred)**2/(tf.maximum(y_true**2,1e-7))
        mean_err_percent = tf.reduce_mean(err_percent)
        return mean_err_percent
    def get_config(self):
        config = super(MSPE, self).get_config()
        return config
import datetime
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(optimizer=optimizer,loss=MSPE(name = "MSPE"))
logdir = "./data/keras_model/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tb_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
#如果loss在100个epoch后没有提升,学习率减半。
lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor="loss",factor = 0.5, patience = 100)
#当loss在200个epoch后没有提升,则提前终止训练。
stop_callback = tf.keras.callbacks.EarlyStopping(monitor = "loss", patience= 200)
callbacks_list = [tb_callback,lr_callback,stop_callback]
history = model.fit(ds_train,epochs=500,callbacks = callbacks_list)

32aei2I.png!web

四,评估模型

评估模型一般要设置验证集或者测试集,由于此例数据较少,我们仅仅可视化损失函数在训练集上的迭代情况。

%matplotlib inline
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
def plot_metric(history, metric):
    train_metrics = history.history[metric]
    epochs = range(1, len(train_metrics) + 1)
    plt.plot(epochs, train_metrics, 'bo--')
    plt.title('Training '+ metric)
    plt.xlabel("Epochs")
    plt.ylabel(metric)
    plt.legend(["train_"+metric])
    plt.show()

QzUFJj6.png!web

五,使用模型

此处我们使用模型预测疫情结束时间,即 新增确诊病例为0 的时间。同时,我们也可以预测新增治愈人数为0,以及新增死亡人数为0的时间作为国内疫情结束时间的参考。

#使用dfresult记录现有数据以及此后预测的疫情数据
dfresult = dfdiff[["confirmed_num","cured_num","dead_num"]].copy()
dfresult.tail()

nYzIR3i.png!web

#预测此后100天的新增走势,将其结果添加到dfresult中
for i in range(100):
    arr_predict = model.predict(tf.constant(tf.expand_dims(dfresult.values[-38:,:],axis = 0)))
    dfpredict = pd.DataFrame(tf.cast(tf.floor(arr_predict),tf.float32).numpy(),
                columns = dfresult.columns)
    dfresult = dfresult.append(dfpredict,ignore_index=True)
dfresult.query("confirmed_num==0").head()
# 第55天开始新增确诊降为0,第45天对应3月10日,也就是10天后,即预计3月20日新增确诊降为0
# 注:该预测偏乐观

Y77RJvB.png!web

第55天开始新增确诊降为0,第45天对应3月10日,也就是10天后,即预计3月20日新增确诊降为0。该预测偏乐观。

dfresult.query("cured_num==0").head()
# 第164天开始新增治愈降为0,第45天对应3月10日,也就是大概4个月后,即7月10日左右全部治愈。
# 注: 该预测偏悲观,并且存在问题,如果将每天新增治愈人数加起来,将超过累计确诊人数。

yeEFfeM.png!web

第164天开始新增治愈降为0,第45天对应3月10日,也就是大概4个月后,即7月10日左右全部治愈。

注: 该预测偏悲观,并且存在问题,如果将每天新增治愈人数加起来,将超过累计确诊人数。

dfresult.query("dead_num==0").head()
# 第60天开始,新增死亡降为0,第45天对应3月10日,也就是大概15天后,即20200325
# 该预测较为合理

6JZ3Ezu.png!web

第60天开始,新增死亡降为0,第45天对应3月10日,也就是大概15天后,即20200325。

注: 该预测较为合理。

于是,根据我们的模型对历史数据的学习,预测国内的新冠肺炎疫情的新增确诊和新增死亡人数有望在3月底基本结束。

Mjiue2Z.jpg!web

当然,本预测主要基于历史数据,与股票的技术分析有着类似的性质,结论仅供参考,在疫情完全结束之前,小伙伴们切不可放松警惕,给新冠病毒送人头。

QRVvM3M.png!web

六,保存模型

model.save('./data/tf_model_savedmodel', save_format="tf")
print('export saved model.')
model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel',compile=False)
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model_loaded.compile(optimizer=optimizer,loss=MSPE(name = "MSPE"))
model_loaded.predict(ds_train)

qAn2UnQ.jpg!web


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK