Deep Learning

Keras: 用 Unet 學習資料讀取及抽換資料的思路

以前自已摸索深度學習時,都是使用現成的資料集去實作相關程式碼,例如 MNIST 數字集、Cifar10 …等,網路上也充斥各種範例程式,雖然可以很快去完成一支訓練程式,很輕鬆的建構出神經網路架構 (ctrl+c ctrl+v),殊不知,讀資料才是最重要的XD

也因為對於資料讀取這件事不夠了解,所以要換成自己的資料就卡關了,也不知道如何製作輸入影像及標記資料,間接導致換資料變得很困難,至少對於一些本身背景就不是寫程式的人更是如此,也因此我們除了學習相關 API 之外,也應該學習如何建立及讀取資料。

本文使用 tensorflow2 + keras API 實作 Unet 架構,並嘗試用最簡單的方式來學習 tf.data.Dataset 這個資料讀取的方法。

關於深度學習的流程

在深度學習中,資料通常可以分成 Input data、Output data 及對應 Ground true data,就是所謂的標記資料,其中 Input data 和 Ground true data 都是要自行準備的,對應上方流程圖就是 Input X 及 Input Y

補充:整個訓練過程大概是 輸入資料>>神經網路(Forward)>>預測結果>>計算預測與標記誤差(loss function)>>優化神經網路權重(Backward)

思考資料讀取的參數

本文以分割架構作為資料讀取範例,如果想做分類資料的讀取程式,可以參考另一篇文章

雖說使用 Unet 作為架構,但不一定要建立分割資料,而是先思考這些資料的相關參數,並將這些參數資訊轉成讀取資料的程式碼,確定資料流的方法以後再來決定如何建立資料庫。

例如:

一開始先以 影像大小 去思考

  • 輸入是一張 512×512 影像
  • 輸出是一張與輸入同樣大小 512×512 的影像

因此可以得知輸入輸出都會是 512×512 的矩陣

接著再去思考 data 裡面的數值

  • 輸入影像使用灰階影像並正規化至 0-1
  • 輸出一張與原圖相同的灰階影像 (輸出一樣的東西?)

由於前面已經先想好會輸出相同大小的矩陣,所以我們可以先使用任何假資料,先去完成資料讀取,將來根據不同應用抽換標記資料即可

寫好資料讀取方式,再去思考輸出的資料,可能應用如下

去模糊、分割、風格轉換、etc.

可以發現這些應用都是將輸入影像重建出另一種形式的結果,所以只要資料讀取方法建立,其實都是大同小異

實作讀取程式

KERAS 其實也有提供讀取資料的 API 可以使用,但是要修改一些東西時就會很麻煩,因為都封裝起來了,但有了前面概念之後,就會發現任何一支讀資料程式都是差不多概念,所以其實是可以自行建構一套讀資料的程式XD。

自從有了 TensorFlow 2 之後,多了許多 API 可以使用,其中 tf.data 就是讓我們有自行決定資料處理的彈性外,又提供一些共通的功能,讓我們不用自己重新造輪子。

如果今天使用別人專案進行實驗,大部分都會附上抽換資料的方法或步驟,但如果有一些概念以後,要使用別人的程式碼也更容易,也比較知道要怎麼進行修改來增加自己的功能。

不管使用什麼方式,都先建立資料的路徑清單!!!

在講解 tf.data 程式之前,我們先來建立資料清單。大部分資料集下載下來,其實大多都會提供讀取清單,通常是訓練、驗證、測試清單。建立清單一方面是確保資料順序一致性,另一面是對每次結果進行比較,對於訓練結果驗證格外重要。

而本文資料路徑清單是利用 glob 來建立,這個清單都會給 tf.data.Dataset.from_tensor_slices()tf.data.Dataset.from_generator() 使用。

# 使用 glob 來獲得資料路徑
# 利用正則表示式來辨識圖片路徑
#
# data\in\ -> 輸入影像(image)的資料夾
#
tra_data_in = glob.glob(r"data\in\*.png")
#
# data\out\ -> 標記影像(mask)的資料夾
#
tra_data_out = glob.glob(r"data\out\*.png")

# 建立 dict 來存放路徑
# 給 from_tensor_slices() 使用
tra_data = {}
tra_data["in"] = tra_data_in
tra_data["out"] = tra_data_out

接下來,我會介紹兩種方式去讀資料

  • tf.data.Dataset.from_tensor_slices()
  • tf.data.Dataset.from_generator()

tf.data.Dataset.from_tensor_slices() + map()

  1. 建立讀取影像的 function
def parse_fn(dataset, **kwargs):
    # 分離 dataset dict
    path_x = dataset["in"]    # for inp image
    path_y = dataset["out"]   # for out mask

    # 讀取 image (使用 tf IO function)
    x = tf.io.read_file(path_x)
    x = tf.io.decode_image(x, channels=1, expand_animations=False)
    y = tf.io.read_file(path_y)
    y = tf.io.decode_image(y, channels=1, expand_animations=False)

    # 對影像進行正規化,及增加影像通道
    x = tf.cast(x, tf.float32) / 255.0
    x = tf.expand_dims(x, axis=-1)
    y = tf.cast(y, tf.float32) / 255.0
    y = tf.expand_dims(y, axis=-1)

    return x, x # 回傳資料
  1. 將資料清單 tra_data 輸入至 from_tensor_slices() 並建立 tra_ds 物件
tra_ds = Dataset.from_tensor_slices(tra_data)
  1. 使用 map() function 來嵌入讀取影像的 function
autotune = tf.data.experimental.AUTOTUNE  # 自動模式
tra_ds = tra_ds.map(lambda ds: parse_fn(ds), num_parallel_calls=autotune)
  1. 為 tra_ds Dataset 物件設定 shuffle、batch size
tra_ds = tra_ds.shuffle(20).batch(1)

特點

  • 可以使用 batch()shuffle()map() …等方法來設定資料訓練的參數
  • map() 來建立資料讀取 funciton,可以避免總資料量過大而無法放進記憶體
  • 使用 map() 可以加入各種預處理的運算 (用 tf 的 API 來運算)
  • 使用 tf.py_function() 可以整合一些 python 的運算,例如各種資料擴增處理 (但官方建議使用 tf API)

tf.data.Dataset.from_generator()

  1. 建立 generator function

這邊使用 python 內嵌函數來添加自訂運算

def create_dataset(data_x, data_y, data_shape):
    def _generator(data_x, data_y, data_shape):   # 直接導入父函數的引數
        H, W = data_shape                         # 取得影像長寬

        for n in range(len(data_x)):              # 用 for 迴圈去讀取資料
            if isinstance(data_x[n], bytes):      # 處理字串
                x = data_x[n].decode("utf-8")
            if isinstance(data_y[n], bytes):      # 處理字串
                y = data_y[n].decode("utf-8")

            # 讀取影像
            X = cv2.imread(x, 0)
            X = cv2.resize(X, (H,W))
            x = X / 255.0
            x = tf.expand_dims(x, axis=-1)
            X.astype('float32').reshape(H,W,1)

            # 讀取影像
            Y = cv2.imread(y, 0)
            Y = cv2.resize(Y, (H,W))
            Y = Y / 255.0
            Y.astype('float32').reshape(H,W,1)
            yield X, Y

    # 回傳 Dataset 物件
    return tf.data.Dataset.from_generator(
        _generator, args=[data_x, data_y, data_shape],
        output_signature=(tf.TensorSpec(shape=data_shape, dtype=tf.float32),
                          tf.TensorSpec(shape=data_shape, dtype=tf.float32))
    )
  1. 建立 Dataset 物件

一般都會是在 generator 中處理批次(batch)資料讀取,但這邊直接用 tf.data API 來處理

tra_ds = create_dataset(tra_data_in, tra_data_out, (512,512)).batch(1)

特點

  • 與上述方法一樣,可以避免記憶體不足問題
  • 能夠更加輕易的導入 python 的運算 (不用使用 tf API,效能會受到 Python GIL 限制)
  • 可以將 Keras API 的 preprocessing.image.ImageDataGenerator 包裝成 tf.data.Dataset 連結

Refer

留下一個回覆

發佈留言必須填寫的電子郵件地址不會公開。 必填欄位標示為 *