Multiple quantile regression with Tensorflow - part1
Estimating multiple quantiles simultaneously with Tensorflow while avoiding the crossing problem
TL;DR
In this post, I describe how to estimate multiple quantiles simultaneously with Tensorflow while preventing the crossing problem, i.e., fitted quantile curves intersecting even though quantiles must stay ordered, using Cannon's approach (the monotone composite quantile regression neural network, MCQRNN). For those short on time, all of the code below is available at RektPunk/mcqrnn-tf2.
Preprocessing
First, I wrote a DataTransformer class that replicates the training data once per input quantile, as shown below. For example, with two samples and the quantiles (0.1, 0.9), each sample appears twice, paired once with each quantile.
# mcqrnn/transform.py
from typing import Tuple, Union

import numpy as np


def _mcqrnn_transform(
    x: np.ndarray,
    taus: np.ndarray,
    y: Union[np.ndarray, None] = None,
) -> Tuple[np.ndarray, ...]:
    """
    Transform x, taus (and optionally y) into the trainable form.
    Args:
        x (np.ndarray): input
        taus (np.ndarray): quantiles
        y (Union[np.ndarray, None]): output
    Returns:
        Tuple[np.ndarray, ...]: transformed (x, y, taus), or (x, taus) if y is None
    """
    _len_taus = len(taus)
    _len_x = len(x)
    _len_data = _len_x * _len_taus
    # Repeat each sample once per quantile and tile the quantiles per sample,
    # so row i of x_trans is paired with row i of taus_trans.
    x_trans = np.repeat(x, _len_taus, axis=0).astype("float32")
    taus_trans = np.tile(taus, _len_x).reshape((_len_data, 1)).astype("float32")
    if y is not None:
        y_trans = np.repeat(y, _len_taus, axis=0).astype("float32")
        y_trans = y_trans.reshape((_len_data, 1))
        return x_trans, y_trans, taus_trans
    return x_trans, taus_trans


class DataTransformer:
    """
    A class to transform data into trainable form.
    Args:
        x (np.ndarray): input
        taus (np.ndarray): quantiles
        y (Union[np.ndarray, None]): output
    Methods:
        __call__:
            Returns Tuple[np.ndarray, ...]: the transformed x, y, taus
        transform(x: np.ndarray, input_taus: np.ndarray):
            Returns transformed x and taus for the given input_taus
    """

    def __init__(
        self,
        x: np.ndarray,
        taus: np.ndarray,
        y: Union[np.ndarray, None] = None,
    ):
        self.x = x
        self.y = y
        self.taus = taus
        if y is None:
            # Without targets, only x and taus can be transformed.
            self.y_trans = None
            self.x_trans, self.tau_trans = _mcqrnn_transform(x=self.x, taus=self.taus)
        else:
            self.x_trans, self.y_trans, self.tau_trans = _mcqrnn_transform(
                x=self.x, y=self.y, taus=self.taus
            )

    def __call__(self) -> Tuple[np.ndarray, ...]:
        return self.x_trans, self.y_trans, self.tau_trans

    def transform(
        self, x: np.ndarray, input_taus: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray]:
        input_taus = input_taus.astype("float32")
        return _mcqrnn_transform(
            x=x,
            taus=input_taus,
        )
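A quick usage sketch makes the replication concrete (the toy data below is hypothetical, not from the repo): 100 samples and 3 quantiles become 300 (sample, quantile) rows.

import numpy as np

x = np.random.normal(size=(100, 1)).astype("float32")
y = 2.0 * x[:, 0] + np.random.normal(scale=0.5, size=100).astype("float32")
taus = np.array([0.1, 0.5, 0.9], dtype="float32")

transformer = DataTransformer(x=x, taus=taus, y=y)
x_trans, y_trans, tau_trans = transformer()
print(x_trans.shape, y_trans.shape, tau_trans.shape)
# (300, 1) (300, 1) (300, 1): each sample is repeated once per quantile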
Layers
Next, each of the layers is defined. In Cannon's work, between the input and the first hidden layer only the weight attached to the quantile input is kept positive via an exponential function, while the weights between all subsequent layers are kept positive via the exponential function as well. Combined with non-decreasing activations, this makes the network output monotonically non-decreasing in the quantile tau, which is exactly what prevents the estimated quantiles from crossing. Making the weights positive this way, we can implement McqrnnInputDense, McqrnnDense, and McqrnnOutputDense, shown after the short illustration below.
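Before the layer code, a tiny hypothetical snippet shows why the exponential does the job: exp(w) is strictly positive regardless of the sign of the raw weight w, so the contribution of tau to the pre-activation is strictly increasing in tau.

import tensorflow as tf

w_tau = tf.Variable([[-2.3]])              # raw weight, possibly negative
taus = tf.constant([[0.1], [0.5], [0.9]])
print(tf.matmul(taus, tf.exp(w_tau)).numpy().ravel())
# strictly increasing in tau, because exp(w_tau) > 0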
# mcqrnn/layers.py
from typing import Callable

import numpy as np
import tensorflow as tf


class McqrnnInputDense(tf.keras.layers.Layer):
    """
    Mcqrnn input dense network
    Args:
        out_features (int): the number of nodes in the first hidden layer
        activation (Callable): activation function, e.g. tf.nn.relu or tf.nn.sigmoid
    Methods:
        build:
            Set weight shapes on the first call
        call:
            Return the activated dense output of the inputs and tau
    """

    def __init__(
        self,
        out_features: int,
        activation: Callable,
        **kwargs,
    ):
        super(McqrnnInputDense, self).__init__(**kwargs)
        self.out_features = out_features
        self.activation = activation

    def build(self, input_shape):
        self.w_inputs = tf.Variable(
            tf.random.normal([input_shape[-1], self.out_features]),
            name="w_inputs",
        )
        self.w_tau = tf.Variable(
            tf.random.normal([1, self.out_features]),
            name="w_tau",
        )
        self.b = tf.Variable(
            tf.zeros([self.out_features]),
            name="b",
        )

    def call(self, inputs, tau):
        # Only the tau weights are exponentiated; the input weights stay unconstrained.
        outputs = (
            tf.matmul(inputs, self.w_inputs)
            + tf.matmul(tau, tf.exp(self.w_tau))
            + self.b
        )
        return self.activation(outputs)


class McqrnnDense(tf.keras.layers.Layer):
    """
    Mcqrnn dense network
    Args:
        dense_features (int): the number of nodes in the hidden layer
        activation (Callable): activation function, e.g. tf.nn.relu or tf.nn.sigmoid
    Methods:
        build:
            Set weight shapes on the first call
        call:
            Return the activated dense output
    """

    def __init__(
        self,
        dense_features: int,
        activation: Callable,
        **kwargs,
    ):
        super(McqrnnDense, self).__init__(**kwargs)
        self.dense_features = dense_features
        self.activation = activation

    def build(self, input_shape):
        self.w = tf.Variable(
            tf.random.normal([input_shape[-1], self.dense_features]),
            name="w",
        )
        self.b = tf.Variable(
            tf.zeros([self.dense_features]),
            name="b",
        )

    def call(self, inputs: np.ndarray):
        # All hidden-to-hidden weights are exponentiated to preserve monotonicity in tau.
        outputs = tf.matmul(inputs, tf.exp(self.w)) + self.b
        return self.activation(outputs)


class McqrnnOutputDense(tf.keras.layers.Layer):
    """
    Mcqrnn output dense network
    Methods:
        build:
            Set weight shapes on the first call
        call:
            Return the linear output
    """

    def __init__(self, **kwargs):
        super(McqrnnOutputDense, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = tf.Variable(
            tf.random.normal([input_shape[-1], 1]),
            name="w",
        )
        self.b = tf.Variable(tf.zeros([1]), name="b")

    def call(self, inputs: np.ndarray):
        # Output weights are exponentiated as well; no activation on the final output.
        outputs = tf.matmul(inputs, tf.exp(self.w)) + self.b
        return outputs
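As a quick sanity check on shapes (a hypothetical snippet; the sizes are arbitrary), the three layers compose into a map from (sample, tau) rows to one prediction per row:

import tensorflow as tf

inputs = tf.random.normal((4, 2))               # 4 transformed rows, 2 features
tau = tf.constant([[0.1], [0.5], [0.5], [0.9]])

h = McqrnnInputDense(out_features=8, activation=tf.nn.sigmoid)(inputs, tau)
h = McqrnnDense(dense_features=8, activation=tf.nn.sigmoid)(h)
out = McqrnnOutputDense()(h)
print(out.shape)  # (4, 1)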
Model
With these layers, the model can be assembled. As an example, I built Mcqrnn with a single hidden layer, as proposed in Cannon's work:
# mcqrnn/models.py
from typing import Callable

import numpy as np
import tensorflow as tf

from mcqrnn.layers import (
    McqrnnInputDense,
    McqrnnDense,
    McqrnnOutputDense,
)


class Mcqrnn(tf.keras.Model):
    """
    Mcqrnn simple structure
    Note that the middle dense network can be modified with McqrnnDense
    Args:
        out_features (int): the number of nodes in the first hidden layer
        dense_features (int): the number of nodes in the hidden layer
        activation (Callable): activation function, e.g. tf.nn.relu or tf.nn.sigmoid
    Methods:
        call:
            Return the model output for the given inputs and tau
    """

    def __init__(
        self,
        out_features: int,
        dense_features: int,
        activation: Callable = tf.nn.relu,
        **kwargs,
    ):
        super(Mcqrnn, self).__init__(**kwargs)
        self.input_dense = McqrnnInputDense(
            out_features=out_features,
            activation=activation,
        )
        self.dense = McqrnnDense(
            dense_features=dense_features,
            activation=activation,
        )
        self.output_dense = McqrnnOutputDense()

    def call(
        self,
        inputs: np.ndarray,
        tau: np.ndarray,
    ):
        x = self.input_dense(inputs, tau)
        x = self.dense(x)
        outputs = self.output_dense(x)
        return outputs
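Putting the pieces together, here is a minimal hypothetical forward pass with untrained weights (the toy data and layer sizes are my own choices). Even before training, the construction guarantees the predictions cannot cross across taus, which we can verify directly:

import numpy as np
import tensorflow as tf

x = np.random.normal(size=(100, 1)).astype("float32")
y = 2.0 * x[:, 0] + np.random.normal(scale=0.5, size=100).astype("float32")
taus = np.array([0.1, 0.5, 0.9], dtype="float32")

transformer = DataTransformer(x=x, taus=taus, y=y)
x_trans, y_trans, tau_trans = transformer()

model = Mcqrnn(out_features=8, dense_features=8, activation=tf.nn.sigmoid)
preds = model(x_trans, tau_trans)
print(preds.shape)  # (300, 1): one prediction per (sample, tau) pair

# Each consecutive block of 3 rows holds one sample's predictions at taus 0.1, 0.5, 0.9.
quantile_preds = tf.reshape(preds, (100, 3))
print(bool(tf.reduce_all(quantile_preds[:, 1:] >= quantile_preds[:, :-1])))
# True: non-decreasing in tau, i.e., no crossing, even with random weights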
Loss
Finally, I defined the objective to be optimized as TiltedAbsoluteLoss, the tilted absolute loss (also known as the check or pinball loss) rho_tau(u) = max(tau * u, (tau - 1) * u) used in quantile regression.
# mcqrnn/loss.py
from typing import Union

import numpy as np
import tensorflow as tf


class TiltedAbsoluteLoss(tf.keras.losses.Loss):
    """
    Tilted absolute loss function, or check loss
    Args:
        y_true (Union[np.ndarray, tf.Tensor]): train target value
        y_pred (Union[np.ndarray, tf.Tensor]): predicted value
        tau (Union[np.ndarray, tf.Tensor]): quantiles
    Returns:
        tf.Tensor: tilted absolute loss
    """

    def __init__(self, tau: Union[np.ndarray, tf.Tensor], **kwargs):
        super(TiltedAbsoluteLoss, self).__init__(**kwargs)
        self._one = tf.cast(1, dtype=tau.dtype)
        self._tau = tf.cast(tau, dtype=tau.dtype)

    def call(
        self,
        y_true: Union[np.ndarray, tf.Tensor],
        y_pred: Union[np.ndarray, tf.Tensor],
    ) -> tf.Tensor:
        # max(tau * e, (tau - 1) * e) equals tau * e when e >= 0 and (tau - 1) * e when e < 0.
        error = y_true - y_pred
        _loss = tf.math.maximum(self._tau * error, (self._tau - self._one) * error)
        return tf.reduce_mean(_loss)
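A hypothetical numeric check shows the asymmetry that makes this a quantile loss: with tau = 0.9, under-prediction costs nine times more than over-prediction of the same size.

import numpy as np

loss_fn = TiltedAbsoluteLoss(tau=np.array([[0.9]], dtype="float32"))
y_true = np.array([[1.0]], dtype="float32")
print(float(loss_fn(y_true, np.array([[0.0]], dtype="float32"))))  # 0.9 = tau * |error|
print(float(loss_fn(y_true, np.array([[2.0]], dtype="float32"))))  # 0.1 = (1 - tau) * |error|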
Conclusion
이제는 모델의 모든 구성요소를 갖췄다. 다음으로 모델을 학습할 차례이다. 하지만, 학습과 시각화까지 소개하자니 코드가 너무 길어질 것 같다… 따라서, 해당 파트는 다음에 다루도록 하겠다..