【Conference Projector】OpenAI API を使って CVPR 2023 全体を眺めるWebサイトを作成した

概要

  • CVPR 2023 会議全体を可視化したグラフを眺めながら論文検索できるWebサイトを作成したので紹介します。

  • 会議に採択された論文全体を可視化したグラフから、 カテゴリやアプリケーションが近い論文を探せます。

  • テキスト検索ではない方法で、広い視野で論文を探せます。

  • 会議全体で盛り上がっている分野や、逆にニッチな分野を把握することもにも役立ちます。

  • 研究テーマを模索している方や、広い視野で業界動向を知りたい方におすすめです。

yuukicammy--conference-projector-wrapper.modal.run

はじめに

こちらの記事にインスパイアされました。

zenn.dev

このシステムは文章やキーワードから論文を検索しますが、具体的なキーワードというよりは会議全体を俯瞰して調査したいなと思ったため今回のWebサイトを作りました。

cvpaper.challengeの活動に興味のある方などは関心が近いかもしれません。

Conference Projector で何ができるか

会議全体の論文を投影した散布図から、ノードをクリックすることでその論文の概要やその論文に近い論文を閲覧できます。 論文のカテゴリ、アプリケーション、タイトル、アブストラクトの観点から会議全体の俯瞰や近しい論文を探すことができます。

ここでカテゴリ*1とは、論文が取り組んでいる特定の研究テーマや研究領域を示します。例えば物体検出や露光補正などです。 アプリケーションとは、論文の技術が想定している応用先を示します。例えば自動運転などです。

使い方はこんな感じです。


www.youtube.com

カーソルを合わせるとの論文情報が表示されます。カーソルを移動させながら素早く多数の論文タイトルや研究カテゴリを把握できます。

Top Screen 会議全体の論文を可視化したページ

ノードをクリックすると、その論文に関する情報が閲覧できます。選択中の論文に近い論文も閲覧できます。

Paper Info Screen 選択した論文とその類似論文を閲覧するページ

システム概要

下記のツールを利用してシステムを構築しました。

下記の手順でシステムを構築しています。

  • (1) スクレイピング   
  • (2) カテゴリ、アプリケーションなどのテキスト生成   
  • (3) Embedding  
  • (4) PDFからの画像抽出   
  • (5) 次元圧縮   
  • (6) K-D Tree構築   
  • (7) Webサイトデプロイ   

(1)から(6)までが前処理で、データを揃えたのちにWebサイトをデプロイしています。 Webサイトは前処理で揃えたデータのみを使っており、新たに OpenAI APIなどへのリクエストが発生することはありません。 これによりWebサイトのレイテンシを小さくするとともに、ランニング費用 (APIリクエスト費用) を抑えています。


pipeline データ前処理のパイプライン

論文を散布図に投影し類似の論文を検索できるようにするまでには、下記のようなプロセスをとっています。

projection data processing 論文を散布図に投影し、近傍探索するまでの処理

詳しいプロセスは次の章で述べます。

実装詳細

(1) スクレイピング

会議に採択された全ての論文情報を抽出します。 CVPR 2023 Open Access のタイトル一覧ページ を最初の入力として論文情報をスクレイピングしました。 抽出したデータはAzure Cosmos DBに保存しています。

ここで抽出する論文情報は下記4つです。

(2) カテゴリ、アプリケーションなどのテキスト生成

会議のWebサイトから抽出できない情報 (カテゴリなど) をOpenAI Chat APIで生成しました。 こちらの方法と同じように、出力のフォーマットを指定する目的でFunction callingを使っています。 モデルには gpt-3.5-turbo-0613 を使いました。2023年6月時点でFunction callingを利用できるモデルがこれとgpt-4-0613しかなく、自分のアカウントではGPT-4のAPIは使えないためです。

論文のタイトルとアブストラクトを含むプロンプトから、 論文のカテゴリ、アプリケーションなど下記6項目をテキスト生成しました。

  • 論文の概要
  • 先行研究より優れている点
  • 提案手法のポイント
  • 実験結果
  • カテゴリ
  • アプリケーション

各項目を日本語と英語で生成するため、計12項目の文章をGPTで生成しています。 利用したプロンプトとFunction callingのスキーマは下記をご覧ください。

プロンプトはバッチ化せず、1リクエストで1論文-12項目のテキストを生成しています。 OpenAI Tokenizer で測ったところ、リクエストの入力トークン長は500程度、出力トークン長は2000程度でした。2359件のリクエストをして各論文12項目のテキストを揃えます。

この GPT-3.5 Turbo によるテキスト生成プロセスが最も時間とお金を消費しました。 ここに書かれている 通りGPT-3.5 Turboはレスポンスが遅く、頻繁に過負荷による 503 Error が出ます。 レスポンスを得るまで20秒程度かかり、2359件揃えるのに最低16時間かかります。 またFunction callingで指定している12項目全ての出力を必須にしても、一部あるいは全ての項目で生成されないされない場合があります。 12項目が全て生成されるまで数回リトライしました。 さらに入力・出力の試行錯誤もあります。間違えたリクエストを大量に送って時間と数千円を無駄にしたりもしました。

(3) Embedding

各論文に関する4つのテキスト (カテゴリ、アプリケーション、タイトル、アブストラクト) をOpen AI Embeddings APIで Embeddingしました。モデルには text-embedding-ada-002 を使いました。 処理時間を短縮する目的で、GPT best practices - Improving latencies を参考にバッチ化してリクエストしました。最大で20のテキストをバッチ化してリクエストできます。一時的に過負荷エラーが出つづけることもあったのですが、Chat APIのテキスト生成に比べれば素早く全ての処理を完了できました。

(4) PDFからの画像抽出

提案手法の概要を示したような画像を取得したくPDFから画像を抽出したのですが、これが一筋縄ではいきませんでした。 CVPR Open Access に置かれたPDFから図を PyMuPDF で抽出したのですが、良さげな画像があまり抽出できませんでした。*2

試行錯誤の上、下記のプロセスで論文の代表画像を抽出しています。

  1. CVPR Open Access にあるarXiv情報、および、arXivでタイトル検索し、arXivに登録されていればソース一式から最大サイズの画像を取得する。
  2. 1に失敗すれば、CVPR Open Access のPDFをPyMuPDFで解析し、表示領域が最大の画像を取得する。

工夫はしましたが、いい感じの代表画像が取得できている割合は感覚的に20%程度です。 ここは改善したいです。

(5) 次元圧縮

各論文情報を散布図に投影するためにEmbeddingを2次元と3次元に次元圧縮します。 次元圧縮の方法はUMAP、t-SNE、PCAの3手法を採用しました。 次元圧縮したEmbeddingで論文の類似度を測っているため (理由は後述)、ここでの次元圧縮結果が非常に重要です。 Webサイト上でユーザに次元圧縮のハイパーパラメータを変更してもらうことも考えましたが、実装負荷とWebサイトのレイテンシの観点から今はやっていません。

(6) K-D Tree構築

類似の論文を高速に検索するために、事前にEmbeddingをノードとしたK-D Treeを構築しています。 K-D Treeは下記の設定で次元削減した24種類 (4 x 3 x 2) のEmbedding群のそれぞれで構築しています。

Embeddingを次元削減する設定 (組み合わせは24種類):

  • Embedding (x4) : カテゴリ、アプリケーション、タイトル、アブストラク
  • 次元圧縮手法 (x3) : UMAP、t-SNE、PCA
  • 圧縮後の次元 (x2) : 2次元、3次元

OpenAI APIで得た1536次元のEmbeddingのまま類似度を測りTreeを構築すればシンプルですが、それだとユーザが散布図で見ている近傍ノード (論文) と、システムが高次元Embeddingから判断した近傍ノードが異なり、ユーザ体験を損ねます。それゆえ、次元削減した2次元・3次元のEmbeddingからK-D Treeを構築しています。

kNN
高次元で探索された近傍ノードは、2次元投影後に近傍にみえない場合が多い

(7) Webサイトデプロイ

Dashを使って構築しました。DashはPlotly社が開発したPythonフレームワークで、FlaskをベースのWebサイトを簡単に構築できます。Plotlyのグラフを使ったインタラクティブなWebが作れるので今回の利用に適していました。Dashを使ったのは初めてでしたが比較的かんたんにWebサイトを構築できました。

今はModal上にDashをのせるかたちでサーバレスなWebサイトを構築していますが、このインフラ選択はまだ悩みがあります。サーバレスとして作るよりモノリスなWebサイトとして構築した方がレイテンシやトラブル回避の観点で良いかもしれません。またWebサイトをDashを使わず構築した方がレイテンシを小さくできるでしょう。今回は使い慣れているためModalを選んだ*3 ことに加え、HTML/CSSのフロントエンドコーディングはあまり頑張りたくないのでDashを使うという判断でした。

役立つものはできたのか?

広い視野からの論文検索

  • 会議全体の論文を眺めながら読みたいと思う論文を探せるという意味では、広い視野からの論文検索は実現できたと思います。
  • 今のままでは情報がのっぺりしていて、自分にとって関心の高い論文に行き着くまでに手間がかかります。

改善に向けて

  • キーワード検索を組み合わせたり、評価された論文をわかりやすく示すなど論文に辿りつくための手がかりがあった方がよさそうです。

カテゴリ・アプリケーション観点での類似論文検索

  • 特定のカテゴリ・アプリケーションはうまく近傍にまとまっており類似の論文が見つけやすいと感じました。
  • カテゴリ・アプリケーションがうまく推定できていないケースも多く存在します。例えばカテゴリがコンピュータビジョンになっているケースがあります。アプリケーションは長文になりがち、かつ、的を得ないケースが多かったです。
  • カテゴリが的確につけられても、カテゴリ間の近さがうまく表せませんでした。例えば画像の露出補正とノイズ除去は近しいような気がしますが、カテゴリでもアプリケーションでも近づけられませんでした。

改善に向けて

  • プロンプトに記載する具体的なカテゴリとアプリケーションの例を増やすことが必要そうです。
  • アプリケーションの推定はPDF本文も参考にした方がいいかもしれません。
  • 近しい論文をより的確にするために、次元圧縮方法などの工夫が必要です。

会議全体のトレンド把握

Category Analysis

  • Zero-Shot/Few-Shotが高い密度でそれなりに大きなクラスタになっており、トレンドを表していることがわかります。
  • 実はプロンプトにカテゴリの例として「3D人物姿勢推定、3D物体追跡、物体検出, 露出補正」を記載しています。これらは比較的うまくまとめられたものの、他のカテゴリは低密度になることが多いです。
  • 距離が離れているとマイナーなカテゴリとみなしやすいものの、低密度なエリアはマイナーなのかどうなのか判断しづらいです。

改善に向けて

  • プロンプトにカテゴリとアプリケーションの具体例を増やす改善に加え、次元圧縮のチューニングとそれらの効率化も必要です。

おわりに

活躍するツールになるにはまだ課題が多いです。初日はほぼ徹夜で一人ハッカソンとして作って完成させましたが、「後もうちょっと」が続き3週間が過ぎました。当初ハッカソン的に作り始めた時よりは実用的に使えそうなツールになりましたが、多数のバックログがあるのと、Embeddingを作り直すのにお金がかかるという問題があります。

作ってみて実感したのは、LLMをモデル更新に関与せず活用するには、実用上プロンプトエンジニアリングから逃れられないということです。 正直これまでプロンプトエンジニアリングの価値を軽視しており、今回もあまり練られていません。 論文に付与するカテゴリ・アプリケーションのテキストを的確に生成することが可視化全体、つまりユーザ体験に強く影響しており、これらを生成するためのプロンプトが重要だと考えています。具体的には、生成させたいカテゴリやアプリケーションの具体名をプロンプトで示すことに効果がありそうです。 プロンプトエンジニアリングを効果的なチューニング作業の一つとして捉えるようになりました。

yuukicammy--conference-projector-wrapper.modal.run

コードはこちら

github.com

英語ブログはこちら

medium.com

こちらの記事も関心が近いです。

xiangze.hatenablog.com

link.medium.com

*1:「カテゴリ」と表現しましたが、トピックと表現した方が良かったかもしれません。これらの観念を複雑に構造化せずわかりやすく扱うことはなかなか難しいです。

*2:PDF形式の画像を埋め込んでる原稿が多く、PyMuPDFではPDFを画像として抽出できないことが原因と考えています。

*3:Modalが個人的に好きだということも選択理由です。個人開発なのでそういうバイアスも大きいです。

MIT-Adobe FiveK Datasetとツールの紹介

概要

  • 豊富なRAWデータが利用できるMIT-Adobe FiveK Dataset*1を紹介する。
  • MIT-Adobe FiveK Datasetを簡単にダウンロードして使うためのツールを作成したので紹介する。

github.com

MIT-Adobe FiveK Datasetとは

MIT-Adobe FiveK Datasetとは一眼レフカメラで撮影された高精細画像のデータセットである。シーンや被写体、照明など、様々な条件下で撮影されている。データセットには下記が含まれている。

  1. 5, 000枚のRAWデータ

    • DNG形式
    • 1枚あたり10MB程度
    • 全体で35個のカメラモデルで撮影されている (私が調べた限りは)
  2. 25,000枚のレタッチ画像

    • TIFF形式
    • 1枚あたり60MB程度
    • RAWデータ1枚あたり5人の専門家*2がそれぞれレタッチ画像を作成
    • 16bitのProPhoto RGBカラースペースでlossless圧縮
    • レタッチにはAdobe Lightroomを使用
  3. セマンティック情報

    • CSV形式
    • 0.2MB程度
    • 各画像の撮影時間帯、場所、照明、被写体のざっくりとしたカテゴリが付与されている
  4. レタッチの編集履歴

    • Adobe Lightroomのカタログファイル (lrcat形式)
    • 1790MB程度
    • 各画像のレタッチ時のスライダー値や調整履歴が残っている

データのサンプル

Raw (DNG) Expert A Expert B Expert C Expert D Expert E Categories Camera Model
a0001-jmac_
DSC1459.dng
tiff16_a/a0001-jmac_DSC1459 tiff16_b/a0001-jmac_DSC1459 tiff16_c/a0001-jmac_DSC1459 tiff16_d/a0001-jmac_DSC1459 tiff16_e/a0001-jmac_DSC1459 {"location":"outdoor","time": "day","light": "sun_sky","subject": "nature"} Nikon D70
a1384-dvf_095.dng tiff16_a/a1384-dvf_095 tiff16_b/a1384-dvf_095 tiff16_c/a1384-dvf_095 tiff16_d/a1384-dvf_095 tiff16_e/a1384-dvf_095 { "location": "outdoor", "time": "day", "light": "sun_sky", "subject": "nature" } Leica M8
a4607-050801_
080948__
I2E5512.dng
tiff16_a/a4607-050801_080948__I2E5512 tiff16_b/a4607-050801_080948__I2E5512 tiff16_c/a4607-050801_080948__I2E5512 tiff16_d/a4607-050801_080948__I2E5512 tiff16_e/a4607-050801_080948__I2E5512 { "location": "indoor", "time": "day", "light": "artificial", "subject": "people" } Canon EOS-1D Mark II

どのような用途で使えるか

このデータセットを変換することで幅広い用途の機械学習で利用できる。 このデータセットの特徴は、他と比べてRAWデータがあることだと思う。例えば、RAWデータからsRGBへの現像時にホワイトバランスを変えることで、ISPのホワイトバランス調整エラーを模擬したデータセットを作成できる*3

また、高精細なレタッチ画像も数多くあるため、超解像などはRAWデータを使わずともレタッチ画像から適用できると思う。

実際に下記の用途で利用されている*4

  • color enhancement *5 *6
  • white-balance editing *7
  • super resolution *8 *9
  • noise reduction *10 *11
  • underexposure *12
  • overexposure *13

MIT-Adobe FiveK Datasetを簡単に使うためのツール紹介

ツールを作成した背景

MIT-Adobe FiveK Datasetはダウンロードも使うのもやや面倒だという問題がある。 公式ページでMIT-Adobe FiveK Datasetをダウンロードするためのアーカイブが提供されている。だが、アーカイブを解凍後のディレクトリ構造が少々複雑である*14。 またこのアーカイブにはレタッチ画像は含まれておらず、レタッチ画像は個別URLで一枚づつダウンロードするしかない。

加えて、 RAWデータを読み込むOSSとして有名なLibRawでは一部のカメラモデルのRAWデータを正しく読み込むことができない *15。 それゆえカメラモデルを限定して試したいこともあるのだが、各RAWデータ(10MB程度)を読み込むまではそれがどのカメラモデルなのかプログラム内で判断できない。

このようなデータセットの複雑さ・使いづらさはMIT-Adobe FiveK Datasetに限らず、公開データセットではよくあることだと思う。公式提供のものをそのまま使うのは面倒なことが多いため、PyTorchなどのライブラリは有名データセットを簡単にダウンロードして使えるように各データセットクラスを用意している。

同じようにMIT-Adobe FiveK Datasetを簡単にダウンロードして扱えるようにしたかったため、このデータセットのツールを作ることにした。

コード

コードはこちら

MIT-Adobe FiveK Datasetをそのまま機械学習フレームワークで使うことは少ないと思う。多くの場合には、このデータセットに何かしらの前処理 (例えばRAW現像する、ノイズを加える、画像サイズを小さくする、など) をして保存したものを利用すると思う。

基本的にはそのような前処理で利用することを想定してツールを作成した。

使い方

使い方はこんな感じ。

class Preprocess:
    def hello_world(self, item):
        print(f"hello world! the current ID is {item['id']}.")

data_loader = DataLoader(
    MITAboveFiveK(
        root=args.root_dir,
        split="debug",
        download=True, experts=["a", "c"],
        download_workers=4, # multi-process for downloading
        process_fn=Preprocess().hello_world),
    batch_size=None,  # must be `None`
    num_workers=4  # multi-process for pre-processing
)
for item in data_loader:
    # pre-processing has already been performed.
    print(item)

このツールのいいところ

このツールは下記の従来課題に対処した。

  • ダウンロードが面倒
    Pythonコード1行で (MITAboveFiveKクラスのインスタンスを生成するだけで) RAWもTIFFも全てダウンロードできるようにした。

  • ディレクトリ構造が複雑でアクセスしづらい / RAWデータを読み込むまでカメラモデルがわからない
    ディレクトリ構造を再設計し、カメラモデルごとにディレクトリを分けた。各データへアクセスしやすくするために、データごとにファイルパスやカテゴリ等のメタデータをまとめ、DataLoaderのイテレーションで受け取れるようにした。また、データダウンロードのために、メタデータをまとめたJsonファイルを用意した。Jsonファイルの例

 {'categories': {'location': 'outdoor', 'time': 'day', 'light': 'sun_sky', 'subject': 'nature'},
 'id': 1384, 'basename': 'a1384-dvf_095',
 'license': 'Adobe', 'camera': {'make': 'Leica', 'model': 'M8'}, 
 'files': 
  {'dng': '/datasets/MITAboveFiveK/raw/Leica_M8/a1384-dvf_095.dng', 
   'tiff16': {'a': '/datasets/MITAboveFiveK/processed/tiff16_a/a1384-dvf_095.tif', 
    'c': '/datasets/MITAboveFiveK/processed/tiff16_c/a1384-dvf_095.tif'}}, 
}

DataLoaderのイテレーションで受け取るデータ例
(ファイルパスは実行時に決定・追加される)

  • データダウンロードや前処理に時間がかかる
    → データダウンロードはMITAboveFiveKdownload_workersを設定することでマルチプロセス化できる。前処理はDataLoaderのnum_workersを設定することでマルチプロセス化できる。

  • 新しいツールは使い方を覚えるまでが面倒
    → PyTorchのDatasetクラス (を継承したクラス) をDataLoader経由で呼び出すという、PyTorchユーザなら慣れ親しんだ方法で利用できるようにした。データセットへ適用したい前処理用関数をprocess_fnへ渡すことで、DataLoader経由で前処理も適用できる。

おわりに

公開データセットをダウンロードして使えるようにするだけで一晩を溶かす人が減りますように。

*1:V. Bychkovsky, S. Paris, E. Chan, and F. Durand. Learning photographic global tonal adjustment with a database of input / output image pairs. CVPR, 2011.

*2:photography students in an art school

*3:M. Afifi and MS. Brown. Deep white-balance editing. CVPR, 2020.

*4:"color enhancement"のカバー範囲が広すぎて独立した項目になっていなくてすみません。

*5:Y. Zhicheng, H. Zhang, W. Baoyuan, P. Sylvain, and Y. Yizhou. Automatic photo adjustment using deep neural networks. TOG, 2016

*6:J. Park, J. Lee, D. Yoo, and I. Kweon. Distort-and-Recover: Color Enhancement Using Deep Reinforcement Learning. CVPR, 2018.

*7:M. Afifi and MS. Brown. Deep white-balance editing. CVPR, 2020.

*8:MT. Rasheed and D. Shi. LSR: Lightening super-resolution deep network for low-light image enhancement. Neurocomputing, 2022.

*9:X. Xu, Y. Ma, and W. Sun, Towards Real Scene Super-Resolution With Raw Images. CVPR, 2019.

*10:J. Byun, S. Cha, and T. Moon. FBI-Denoiser: Fast Blind Image Denoiser for Poisson-Gaussian Noise. CVPR, 2021.

*11:S. Guo, Z. Yan, K. Zhang, W. Zuo, and L. Zhang. Toward Convolutional Blind Denoising of Real Photographs. CVPR, 2019.

*12:R. Liu, L. Ma, J. Zhang, X. Fan, and Z. Luo. Retinex-Inspired Unrolling With Cooperative Prior Architecture Search for Low-Light Image Enhancement. CVPR, 2021.

*13:M. Afifi, KG. Derpanis, B. Ommer, and MS. Brown. Learning Multi-Scale Photo Exposure Correction, CVPR, 2021.

*14:"HQa1to700", "HQa701to1400"と700枚づつRAWデータが分かれてディレクトリに保存されているかとお思えば、"HQa1400to2100"という1始まりでないディレクトリも存在する、など。

*15:正確には、TIFFで提供されるレタッチ画像(Adobe Lightroomで現像)と、LibRawで現像した画像サイズが一部のカメラモデルで一致しない。特にFujifilmCMOSセンサは独特なので要注意である。

GAS自作子ども向けカレンダーを3ヶ月使ってみた感想

概要

  • 年末にGASで作った子ども用カレンダー(下記ブログ参照)の経過報告エントリである。
  • 3ヶ月間使った結果、4歳の子どもが「一週間」「曜日」といった日にちの感覚を身につけ始めた。
  • カレンダーをアップデートし、オプションで (1) 漢字+ふりがな表記の選択と (2) 5種類のカラーバリエーションから選択できるようにした。

irohalog.hatenablog.com

自作の子ども用カレンダーを3ヶ月間使った感想

4月に入り、カレンダーも4枚目になった。 手書きを面倒に感じる性格なので、Googleカレンダーの最新予定を入れて印刷できるのは便利でよい。

何より、子どもに変化があったのが嬉しい。

子どもが「一週間」「曜日」といった日にちの感覚を身につけ始めた。これはいずれ体得するものではあるが、今の段階から明日はピアノのお稽古、その次の日は幼稚園お休み、といったようなことが理解されると親との会話が成り立ちやすく、助かるなという印象である。

また子どもが「自分が小さい時に、あんなことやこんなことがあったことを覚えているよ」というような感じで話し始めたのだ。カレンダーには毎月10枚程度の写真をコラージュして載せている。子どもがその写真を毎日眺めている中で、覚えているのか新たな記憶として作られているのか、これまでの楽しい思い出に再びふれているようである。

たかがカレンダーではあるが、自分は下記のような効果を狙っている。

  • 子どもが読めるカレンダーにすることで、子どもが自然と「ひと月」「一週間」「曜日」といった日にちの感覚を身につけられるようになる。
  • 二十四節気」「朔弦望」などの情報を入れることで、「季節」「月の満ち欠け」といった周期性を学べるようになる。
  • その月の思い出写真をカレンダーにのせることで、子どもが自分の成長や家族の愛を感じられるようになる。

まぁ、何より自己満足のためにしばらくトイレにこのカレンダーを貼って過ごすことにする。

カレンダーのアップデート

子どもの成長に合わせて、ひらがな表記だけでなくオプションで漢字併記も選べるようにした。 また、季節や気分に合わせて選択できるように、5種類のカラーバリエーションからカレンダーの色を選択できるようにした。

sample calendar
漢字併記+色変更のカレンダー(画像はフリー素材使用)

子ども向けカレンダーをGAS (Google Apps Script) で生成するプログラム

ProRawをRAW現像する

概要

  • 前エントリの続き irohalog.hatenablog.com

  • この記事を参考に、イメージングパイプラインを学ぶ目的でProRawデータ(DNGファイル)をC++でRAW現像する。

  • RAW現像結果から各処理の効果を確認するとともに、LibRaw及びMacPreviewでRAW現像した結果と自分のRAW現像結果とを比較する。

【目次】

コード

コードはこちら。

github.com

DNGファイルの読み込みには、LibRawを用いる。 データはNumPy風に多次元配列を扱える Xtensor で扱い、行列演算などの処理を行う。 処理した結果はOpenCV を用いて保存する。

Xtensor及びOpenCVを用いるのは、自分がすでに使える環境を持っていて慣れているからである。今回の処理だけを考えれば、もっと導入のしやすい選択があったかもしれない。

この記事に出てくるコードはGitHubに置いたコードから抜粋し、簡略化したものである。

ProRawデータをRAW現像する

今日はバレンタインデー🍫 ということで、今回は最終的にこんな感じになる美味しそうな(美味しかった)ケーキのProRawデータを処理していく。

ProRawの読み込み

LibRawを用いて、ProRaw(DNGファイル)を読み込む。

LibRaw raw;
int res = raw.open_file(input_filename.c_str());
res = raw.unpack();

ProRawの画像データをXtensorに格納する。

xt::xtensor<ushort, 2> image({4, (std::size_t)(raw.imgdata.sizes.iheight *
                                              raw.imgdata.sizes.iwidth)});
for (int i = 0; i < image.shape()[1]; i++) {
     xt::view(image, xt::all(), i) =
          xt::adapt(raw.imgdata.rawdata.color4_image[i], {4});
}
image = xt::view(image, xt::range(0, 3), xt::all());

今回はピクセル毎の処理しかしないため、画像の二次元構造を無視してChannel毎にデータを一列に並べて扱う。 この時点でimageのサイズは、(3 Channels, 全ピクセル数) であり、ChannelはRGB順に並ぶ。

下記は、処理前のProRawデータをそのまま8-bitに変換し、PNG形式で保存した画像である1。うっすらケーキが見えるかな。

“処理前のProRawデータ”

ブラックレベル補正

ProRawはブラックレベルが0であり、ブラックレベル補正は必要ない。ProRawとして保存する前にiPhoneで処理してくれていると思われる。

カラー補正

画像データをカメラ色空間からsRGB'へ変換する。この変換に用いる行列については、前回のエントリへ記載した。

// Conversion Matrix from Camera Native Color Spaxce to sRGB'
xt::linalg::dot(srgb_to_cam, image);

ここで、srgb_to_camはカメラ色空間からsRGB'へ変換する行列であり、LibRawのimgdata.color.rgb_camをXtensorに格納した行列である。

camera_to_sRGB()

輝度とコントラスト調整

シンプルなヒストグラムストレッチ (Histogram stretching) で画像全体の輝度とコントラストを調整する。 ピクセル全体から、輝度値の上位・下位から指定された割合(rate)を0あるいはMAX値へ移動させるようにRGB値を線形変換し、ヒストグラムを引き延ばす。

ヒストグラムストレッチによる輝度ヒストグラムの変化
([左]: ヒストグラムストレッチ未適用の場合、[右]: ヒストグラムストレッチ適用の場合 [rate=0.12] )

adjust_brightness()

なお、同様の処理はLibRawでno_auto_bright=falseで適用される(defaultで適用)。 raw.imgdata.params.no_auto_bright = 0;

ガンマ補正

sRGB'値にガンマ補正を適用し、非線形なsRGB値へと変換する。 次の式を線形なsRGB'値のそれぞれに適用し、非線形なsRGB値に変換する2

 \displaystyle
C_{sRGB} = \left\{
\begin{array}{ll}
12.92 C_{sRGB'} & (C_{sRGB'} \leq 0.0031308) \\
1.055 C_{sRGB'}^{1/2.4} - 0.055&  ( C_{sRGB'} > 0.0031308)
\end{array}
\right.

ここで、  C_{sRGB'} はsRGB'のR, G, Bを範囲[0, 1]で表した値であり、  C_{sRGB} はsRGBへ変換後のR, G, B値である。

コードはこんな感じ。

if ( image(ch, i) < 0.0031308 * USHRT_MAX) {
     image(ch, i) *= 12.92;
} else {
     float value = static_cast<float>(image(ch, i)) / USHRT_MAX;
     value = (std::pow(value, 1. / 2.4) * 1.055) - 0.055;
     image(ch, i) = value * USHRT_MAX;
}

gamma_correction()

RAW現像結果

様々なパタンでRAW現像した結果を比較し、各処理の効果を確認する。

C: カラー補正適用
A: 輝度とコントラスト調整適用 (特に指定がなければrate=0.12)
G: ガンマ補正適用

カラー補正の効果確認

カラー補正の有無で結果を比較する。

C+A+G A+G
“Result “Result

カラー補正を適用することで、色味(彩度)が本物に近づいている3

輝度とコントラスト調整の効果確認

輝度とコントラスト調整のrateを変化させた結果は下記の通り。

C+G
rate=0.00
C+A+G
rate=0.04
C+A+G
rate=0.08
C+A+G
rate=0.10
C+A+G
rate=0.12
Result image (threshold=0.00) Result image (threshold=0.04) Result image (threshold=0.08) Result image (threshold=0.10) Result image (threshold=0.12)

輝度とコントラスト調整を入れない場合には画像は全体的に暗いが、適用することで画像が明るくなっている。 また、輝度とコントラスト調整で引き伸ばしを大きくする(よりストレッチさせてヒストグラムを平坦にする)ことで、コントラストが高まっている。

ガンマ補正の効果確認

ガンマ補正の効果を確認する。

C+A+G C C+A C+G
“Result “Result “Result “Result

ガンマ補正も輝度とコントラスト調整も未適用の場合には画像が暗くなる(左から2番目の図)。

ガンマ補正なしで輝度とコントラスト調整を行なった場合には、コントラストが強調されすぎて不自然な印象である(左から3番目の図)。輝度とコントラスト調整のrateを下げることで不自然なコントラスト強調は抑制されるだろうが、その場合にはおそらく画像全体が暗くなっていくだろう(左から2番目の図に近づく)。

ややズレるが、カラー補正とガンマ補正だけの画像もここに載せておく(一番右の図)。左から2番目の図に比べて明るくなっているが、これだけでは何が写っているのか分かりにくい。

LibRaw及びMac Previewで現像した結果との比較

今回はProRawのイメージングパイプラインを学ぶことが目的であり、いい感じにRAW現像することには注力していないが4、LibRaw及びMacPreviewで現像した結果との比較をのせておく。

自分の結果は、全部入り「C+A+G」パタンで現像した結果である(rate=12%)。

LibRawは、raw.imgdata.params.use_camera_wb = 1;のみ指定し、他はデフォルトパラメータを用いて変換した結果である。これは、dcraw -c -wで変換した結果に相当する。LibRawはTIFF形式でしか保存できないため、LibRawの変換結果をOpenCVのcv::Matに入れてPNG形式で保存した5。 コードはこちら

Mac Previewは、DNGファイルをPreviewで開いて、8-bit PNG形式にエクスポートした画像である。

自分の結果 LibRaw Mac Preview
“Result “Result “Result
“Result “Result “Result
“Result “Result “Result
“Result “Result “Result

全体的に、Previewの変換結果はさすがだな...と思う。一行目、自分の結果はケーキのクリームが白飛びしているのに対し、Previewのクリームは質感がしっかり残っている。二行目、Previewの結果は黒つぶれを防ぎ、全体がわかるように調整されている。三行目、四行目、やはり自分の結果は白飛びが目立つのに対し、Previewの結果は色味(なんというか、彩度なんだけどそれだけじゃない)が絶妙にいいと思う。LibRawの結果も、白飛び・黒つぶれはあるものの、デフォルト設定の割に悪くないね。


RAWは8-bit画像になってしまっては中々取り戻せない鮮明な世界だ。いかにセンサーが捉えてくれた鮮明さをその後残し続けられるか、RAWの12-bit/16-bitの良さを活かすImage Processingに今後も期待したい。


  1. DNGファイルの画像データは12-bitで、これを16-bit(ushort)で格納する。PNG形式で保存する際に、一列に並べていたピクセル値を縦横の位置に戻し、16-bit(ushort)から8-bit(uchar)に変換して保存する。OpenCVはchannelの並びがBGRなので、channelの並びも変更することに注意する。
  2. この数式はパラメータも含めてIEC 61966-2-1:1999 で定めされている。
  3. 近づけるべき「本物」とは何か、あるところまでいくと難しい問題である。実物を見た記憶よりもむしろ、画像をディスプレイで表示し、それを目で見たときに素敵に見えるかどうかが重要かもしれない。実際のイチゴよりも赤いほうが好まれるのだ、多分。
  4. 言い訳です!
  5. 念のため同じ結果をLibRawでTIFF形式でも保存しており、OpenCVPNG形式で保存した画像と見た感じ同じであることを確認した。GitHubのdata内に双方の画像を保存している。

iPhone ProRawのカメラ色空間→sRGB' 変換行列を求める

概要

RAW画像処理に興味のある今日この頃。

下記の記事を参考にRAW現像をしてみたいが、まずはiPhoneで撮影したRAW(ProRaw)を、sRGB'(ガンマ補正前のsRGB)に変換する方法についてまとめておく。 uzusayuu.hatenadiary.jp


2023.2.6追記

ProRawを触ってみたい(プログラムで扱ってみたい)と思っているなら、下記の順にみておくことが最短ルートと思われる。 ネットで情報を漁っても自分はうまくこの情報に辿り着けなかったので、ここに記載しておく。

  1. Capture and process ProRAW images
    Apple が提供するProRawの概要を説明したDeveloper向け動画。英語字幕もTranscriptもあるので、倍速でみても概要はつかめると思う。

  2. DNG標準仕様書 v1.6
    DNGの標準仕様書に各メタデータの詳細が記載されていた。v1.6ではProRaw用に追加されたタグ(Semantic Masks)があり、説明されている。また、「Mapping Camera Color Space to CIE XYZ Space」節に、DNGのどの値をいかに使ってカメラ色空間→sRGBへの変換を実現するか、知りたいことは全て書いてあった。 helpx.adobe.com


【目次】

iPhoneでRAW画像を取得する方法

Appleのサイトを参考にiPhoneの標準カメラでRAW画像を撮影し、DNGファイルを取得することができた。 しかし、データはBayer配列ではなかった。
iPhone標準カメラ経由で取得されるProRawはBayer配列ではなく、すでにデモザイク後のデータになっているようだ1
ProRawファイルついては、下記の2つのブログが大変参考になった。

iPhoneでBayer配列を取得するなら、3rd Partyのアプリを使うしかないようだ2

デモザイキングも奥深いようだが3、今回はデモザイキング後のProRawファイルを使うことにする。

カメラ色空間からsRGB'への変換

LibRawではimgdata.color.rgb_camにsRGBへと変換する行列が格納されており、同じ値がrawpyではcolor_matrixに格納されている。 rgb_cam (ないしはcolor_matrix) とProRawのメタデータ (Exif) に格納されているColor Matrix 2 (0xc622) との関係性が知りたかったので、LibRaw内部の処理などを探ってみた。

まず、Exif Tag の解説4によれば、Color Matrix 2はXYZ表色系からカメラ色空間へと変換する行列のようだ。 ExifCalibration Illuminant 2がD65となっていることから、D65光源でのXYZ表色系からカメラ色空間への変換行列である。

Exif Toolの値(一部)
Exif Toolの値(一部)

XYZ表色系→カメラ色空間へと変換する行列cam_xyzExifColor Matrix 2を入れる(命名規則についてはこちら5)。

sRGB'(ガンマ補正していないsRGB)からXYZ表色系への変換行列を、XYZ表色系→カメラ色空間へと変換するColor Matrix 2にかけることで、sRGB'→カメラ色空間へと変換する行列cam_rgbが算出される。 この擬似逆行列rgb_camであり、カメラ色空間→sRGB'へと変換する行列となる。

rawpyで確認してみる。

まず、iPhoneで撮影したProRaw/DNGファイルから、rawpyのcolor_matrixの値を求める。

    raw = rawpy.imread("proraw_sample.DNG")
    print("Color Matrix:")
    print(raw.color_matrix)
Color Matrix:
[[ 1.3842709  -0.3266568  -0.05761418  0.        ]
 [-0.18470636  1.3894675  -0.20476116  0.        ]
 [ 0.03309519 -0.6075885   1.5744933   0.        ]]

これがLibRawのcam_rgbの値と同じであることは確認した。

次に、Exifデータからcam_rgb, color_matrixと同じ値を算出する。

XYZ表色系→カメラ色空間の変換行列をcam_xyzとして定義する。これとsRGB'→XYZ表色系への変換行列xyz_srgb内積をとりsRGB'→カメラ色空間の変換行列cam_srgbを求める。

    # XYZ to Camera Native Color Space Matrix (D65) in Exif
    cam_xyz = np.array([[0.9145434499, -0.3222275078, -0.1262248605], [-0.4288679957,
                                  1.309540987, 0.09467574954], [-0.1062918678, 0.2350454628, 0.4307328463]])

   # sRGB' to XYZ
    xyz_srgb = np.array([[0.4124564, 0.3575761, 0.1804375],
                         [0.2126729, 0.7151522, 0.0721750],
                         [0.0193339, 0.1191920, 0.9503041]])
    cam_srgb = np.dot(cam_xyz, xyz_srgb)
    print("sRGB' to Camera-Native-Color-Space Matrix:")
    print(cam_srgb)

cam_srgbの値

sRGB' to Camera-Native-Color-Space Matrix:
[[0.56836791 0.1513202  0.04047686]
 [0.10344498 0.79445276 0.107103  ]
 [0.03339166 0.41852832 0.93916176]]

cam_srgbを正規化して擬似逆行列を求めることでrgb_cam, color_matrix と同じものが求められた。

    # Normalize
    norm_cam_srgb = np.empty_like(cam_srgb)
    for r in range(0, 3):
        sum = np.sum(cam_srgb[r, :])
        if 0.00001 < sum:
            norm_cam_srgb[r, :] = cam_srgb[r, :] / sum
        else:
            norm_cam_srgb[r, :] = 0
    print("Normalized sRGB to Camera-Color-Space Matrix:")
    print(norm_cam_srgb)

    # Camera Native Color Space to sRGB'
    srgb_cam = np.linalg.inv(norm_cam_srgb)
    print("Camera-Native-Color-Space to sRGB matrix:")
    print(srgb_cam)
Camera-Native-Color-Space to sRGB matrix:
[[ 1.38427096 -0.32665678 -0.05761418]
 [-0.18470636  1.38946752 -0.20476116]
 [ 0.03309519 -0.60758851  1.57449332]]

ふむ、LibRawの値と同じである。

なお、LibRawの内部処理を追っていくと、ExifAnalog BalanceColor Matrix 2に掛けてcam_xyzを定めていた。 しかしながら、少なくとも今回のProRawの場合にはAnalog Balanceを掛けても掛けなくても正規化後には同じ値になる。

また、Analog BalanceExif 2.3 standardでは定義されいないが、ここを読む限り、(理想的には)RAW画像として保存される前に適用されたアナログゲインのようだ。

いずれにせよ、今回は無視してよさそうである。


  1. Appleのサイトを見てもProRawとしか言ってなくて、RAWがとれるとは言ってない。
  2. iPhoneのAPIにはBayer配列でRAW画像を取得したり、ProRawを取得したりするAPIが用意されている。ChatGPTにBayer配列のRAWが撮れるiPhoneアプリを聞いたところ、Lightroom、ProCamera、VSCO、Halideとのこと。
  3. Adobe Lightroomのデモザイク処理ではCNNを使っているようだ。この記事は興味深い。https://business.adobe.com/blog/the-latest/enhance-details
  4. https://exiv2.org/tags.html
  5. 変数をout_inの形で書くことにする。LibRawでもそのように命名している。これは最初混乱したが、行列の計算を考えると、左辺に出力、右辺に入力がくるからだと思われる。行列の計算順になっているのだと思えばLibRawの変数も理解しやすくなった。

子ども向けカレンダーをGASで作成してみた

概要

sample calendar

はじめに

2023年、あけましておめでとうございます。 昨年は緑豊かな地方へ引っ越したり、第二子を出産したりしました。 皆さまにとって新しい一年が素晴らしい年となりますよう、心よりお祈り申し上げます。

さて、昨年末に今年のカレンダーを自作しました。

夫婦間はGoogleカレンダーで予定を共有していますが、年少の子どもとはこれまで口頭で予定を共有していました。

しかしながら、子どもとうまく予定を共有できていないなと思うことが多々ありました。おそらく「今月」「来週」といった日付の概念がまだしっかり理解できていないからだと思います。

まだタブレットGoogleカレンダーを毎朝見てもらうというのも難しい(面倒な)お年頃なため、紙のカレンダーを使って予定を共有しながら日付の概念も学んでもらおうと思い立ちました。

こども向けカレンダーの要件

作成するカレンダーの要件として考えたものは下記です。

  • [MUST] その月に撮影した家族の写真を入れる(例:1月のカレンダーには過去の1月に撮影した家族の写真を入れる)
  • [MUST] 子どもが読めるように基本的にひらがなで書く(まだひらがなと数字しかしっかり読めないため)
  • [できれば] カレンダーから月の満ち欠け、行事、季節感などの学びがあるようにする
  • [できれば] 月曜始まり

なぜ自作するのか

まず初めに [MUST] の項目を満たすカレンダーサービスを探しました。最初に思いついたのは、みてね のカレンダー作成サービスです。これまでみてね にアップロードしてきた写真を使ってオリジナルカレンダーが作成できる1 のですが、残念ながら曜日などがひらがなで書かれたフォーマットがなく、幼児が日付を学びながらカレンダーを利用することは難しいと思いました2

他にも写真入りカレンダーを作成できるサービスや、手作りカレンダーの販売はいくつかありましたが、カレンダーフォーマットが子ども向けであるもの(特に曜日がひらながで書かれたもの)は見つけられませんでした。

写真を入れられるカレンダーは子ども向けフォーマットじゃないし、子ども向けカレンダーは写真が入れられない(PDFや紙のカレンダー商品を購入する感じでカスタマイズ不可)という状況でした。

ということで、[MUST] を満たす既成品が見つからなかったので、自作するしかありませんでした。

自作カレンダー

Googleスプレッドシートで作成しました。

手作業もトライしましたが12ヶ月分となるとあまりに面倒だったため、来年以降も使うかもしれないしと思ってGASでカレンダーのシートを生成しました。 カレンダー程度ならGASで細かいレイアウトまでコードでかけたので、シートへの手作業操作が不要で結果的に楽でした。

sample calendar pdf
カレンダー作成のイメージ

「祝日」「二十四節気」「朔弦望(月の満ち欠け)」などは国立天文台暦計算室 が公開しているGoogleカレンダーをインポートしました。 加えて、家族の誕生日などを家族用のGoogleカレンダーからインポートしています。

また、子ども向けカレンダーへ表示する用語変換ルールをスプレッドシートにあらかじめ記載して利用しています。

例:1日→ついたち、元日→がんじつ 🎍

各月に貼り付ける写真はGoogleドライブから取得しています(URL直接指定も可能)。

おわりに

最近子どもが一人でトイレに行くようになったので、トイレにカレンダーを貼って予定を共有しています。

12ヶ月分をいっきに作成して簡単に製本するもの良いかなと思いましたが、今のところは1ヶ月づつ作成して貼っていくつもりです。

カレンダーに入れる写真を毎月選ぶのがこれから楽しみです。

コードはこちら


  1. https://gift.mitene.us/blogs/column/campaign_20210702
  2. みてね は私の大好きなサービスです。プレミアム会員です!是非子どもが使えるカレンダーを作れるようにしていただきたいです。

C++DNNフレームワークの関数設計に関する考察あるいはポエム

概要

はじめに

久しぶりのブログ更新になってしまった。 最近は自作DNN (Deep Neural Network) フレームワークの実装が楽しい。この自作DNNフレームワークには、kuuという名前をつけた。中二っぽくていい!

具体的にどのようにDNNフレームワークで学習を行うかは、計算グラフの理論、DNNにおけるBackpropagationのアルゴリズムフレームワークに含まれる個々のアルゴリズムなどを知っているだけでは自明でない。 自分自身、以前はDNNフレームワークを使いながら特にBackpropagationの具体的な実現方法がブラックボックスとなっていたところがあり、自作DNNフレームワークを実装してみることにした。

ゼロから作るDeep Learning ❸ ―フレームワーク編やそのシリーズの存在は知っていたものの、kuuのドラフト版が出来上がって一通り動くようになってからようやく読んだ。 結果的には、作ってから読むことで本に書いてある考察や著者の悩みポイントにをより深く理解できたと思う。 またシリーズのキャッチコピー (?) である「作る経験はコピーできない。」「作るからこそ、見えるモノ。」という言葉にはちょっと感動してしまった。

自作DNNフレームワークを作ること自体は、自分のための勉強であり、世の中に貢献するような取り組みではないと思う。 しかし作った際の悩みポイントを紹介することで、既存DNNフレームワークの利用者がその設計に感じていた疑問や、気持ち悪さや、愚痴が多少減るのではないかと思ったのでブログへ投稿することにした。

例えば私は「なぜPyTorchもChainerもModule/Linkfunctional/FunctionNodeを分けるのだろう?Linearとかどっちにも定義しなければいけなくて面倒じゃない?」 という疑問を持っていたのだが、それはkuu設計の初期段階で解消された。 学習対象パラメータを保持しておくためのクラスと、forward/backwardで行う操作とを別で定義しないとどうしてもうまく設計できなかった。また役割が違うのだという事を理解すれば、むしろ分けることの方が自然に思えた。

ChainerVariableNodeFunctionNodeを相互に呼び出しあって作り出す計算グラフの設計は美しい。 ただそれはPythonだからきれいに実装できたところが大きいように思う 。ゼロから作るDeep Learning ❸ ―フレームワーク編で紹介されている設計もPythonベースである。 C++でDefine-by-Runを実装してみると、 型や相互参照の制約などでなかなか一筋縄には設計できない。

設計で悩んだ(悩んでいる)点は数多くあるが、この記事ではkuuでFunctionを設計していた際のメモを共有する。

Function

前提

DNNフレームワークの実装において、所与の変数を変形する操作をFunctionと呼ぶことにする。 Functionは、forward-pathで行う操作 (以下、forward) とbackward-pathで行う操作 (以下、backward) とがある。Functionの例として、例えばRelu, Linear, Batch Normalizationなどがある。

Function設計の考察

学習対象パラメータはFunctionとは別の機構 (PyTorchのModuleやChainerのLinkなど) で情報を保持している事とする。 この場合に、必要な変数を仮引数として全て受け取ってしまえば、forwardはstaticメソッドかあるいはクラスに属さない関数でよい。 一方、backwardはforwardで利用した実行時の情報をどこかに保存し、それらを用いて処理を行う必要がある。

Functionのforwardとbackwardの実現方法には、少なくとも下記2つのパターンがある。

  • [パターンA] クラスのメソッドとしてforwardとbackwardを定義する方法。例えば、ChainerのFunctionNodeはこのタイプである。

    f:id:Ytra:20200713093010p:plain
    パターンA

  • [パターンB] forwardとbackwardをバラバラに定義する方法。PyTorchやTensorFlowの (ライブラリ内で実装されている) Functionはこのタイプである。例えば、Pytorchでは名前空間torch::nn::functional配下でforward関数が一元的に定義され、名前空間torch::autograd::generated配下でforward関数に対応するbackwardクラスが定義されている。

    f:id:Ytra:20200711153901p:plain
    パターンB

パターンA

下記はパターンAでFunctionを実装する例である。

struct Func1{
    Tensor static forward(Tensor x);
    void backward(Tensor grad);
    Tensor operator() (Tensor x) {
        return forward(x);
    }
};

struct Func2{
    Tensor static forward(Tensor x);
    void backward(Tensor grad);
    Tensor operator() (Tensor x) {
        return forward(x);
    }
};

struct Func3{
    Tensor static forward(Tensor x);
    void backward(Tensor grad);
    Tensor operator() (Tensor x) {
        return forward(x);
    }
};

この場合にforwardを呼びだす方法は

auto y = Func1::forward(Func2::forward(Func3::forward(x)));

あるいは

auto y = Func1{}(Func2{}(Func3{}(x)));

となる。

Pros.

  • パターンBと比較して、forwardとbackwardの対応関係を把握しやすく、新しいFunctionを実装しやすい。

Cons.

  • パターンBと比較して、forwardメソッドを呼ぶ方法では記述量が多く可読性が低くなるし、operator()で呼ぶ方法ではわざわざインスタンスを生成しなければならない。[注1]

[注1] Cons.を解決するためにoperator()をクラスから直接呼び出せるようにstatic化したくなる。同じ要望を持っている人がいて、c++標準化で提案されているようだ

パターンB

下記はパターンBでFunctionを実装する例である。

namespace forward {
Tensor func1(Tensor x);
Tensor func2(Tensor x);
Tensor func3(Tensor x);
}

namespace backward{
struct Func1Backward(Tensor grad){
    void apply(Tensor grad);
};
struct Func2Backward(Tensor grad){
    void apply(Tensor grad);
};
struct Func3Backward(Tensor grad){
    void apply(Tensor grad);
};
}

この場合にforwardの呼び出しは下記のようになる。

namespace F = forward;
auto y = F::func1(F::func2(F::func3(x)));

Pros.

  • FunctionとしてのRelu, Linear, Batch Normalizationなどを関数として利用できるのは自然で可読性も高い (個人感)。

Cons.

  • forwardとbackwardの関連を人手で書かないといけないため、backward呼び出し時に間違えないよう注意が必要。

Chainerでの実装例 (パターンA)

簡潔に呼び出せない問題を解決するために、Chainerでは下記のように、FunctionNode Classのforwardを呼び出す関数を別途定義している。

class ReLU(function_node.FunctionNode):
    ...
    def forward_cpu(self, inputs):
    ...
    def backward(self, indexes, grad_outputs):
    ...

def relu(x):
    y, = ReLU().apply((x,))
    return y

TensorFlowの実装例 (パターンB)

backwardだけでなくforwardもstructで定義され、operator()で呼びだす (つまりインスタンスから呼びだす必要がある) 。

PyTorchでの実装例 (パターンB)

[注2] /autograd/generated/Functions.hはPyTorchにより自動生成されるコードのため、本家リポジトリではない場所を参照する。

PyTorch の場合は、Atenで実装されている関数はderivatives.yaml にforwardとbackwardの紐付けを記載している。このderivatives.yamltemplates配下の定義 (特にこれこれ) を利用して、自動生成スクリプトのなかで autograd/generated/Function.hautograd/generated/Function.cpp を生成する。

例えばtorch::nn::functional::relu()に対応するbackwardは、autograd/generated/Function.hstruct ReluBackward0struct ReluBackward1として定義される。 そして、torch::nn::functional::relu() (forward) が呼び出されると、下記のようにReluBackward0を計算グラフに登録する。

torch/autograd/generated/VariableType_0.cpp

Tensor relu(const Tensor & self) {
  RECORD_FUNCTION("relu", std::vector<c10::IValue>({self}), Node::peek_at_next_sequence_nr());
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<ReluBackward0> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::shared_ptr<ReluBackward0>(new ReluBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_ = SavedVariable(self, false);
  }
  torch::jit::Node* node = nullptr;
  std::shared_ptr<jit::tracer::TracingState> tracer_state;
  if (jit::tracer::isTracing()) {
    tracer_state = jit::tracer::getTracingState();
    at::Symbol op_name;
    op_name = jit::Symbol::fromQualString("aten::relu");
    node = tracer_state->graph->create(op_name, /*num_outputs=*/0);
    jit::tracer::recordSourceLocation(node);
    jit::tracer::addInputs(node, "self", self);
    tracer_state->graph->insertNode(node);
  
    jit::tracer::setTracingState(nullptr);
  }
  #ifndef NDEBUG
  c10::optional<Storage> self__storage_saved =
    self_.has_storage() ? c10::optional<Storage>(self_.storage()) : c10::nullopt;
  c10::intrusive_ptr<TensorImpl> self__impl_saved;
  if (self_.defined()) self__impl_saved = self_.getIntrusivePtr();
  #endif
  auto tmp = ([&]() {
    at::AutoNonVariableTypeMode non_var_type_mode(true);
    return at::relu(self_);
  })();
  auto result = std::move(tmp);
  #ifndef NDEBUG
  if (self__storage_saved.has_value())
    AT_ASSERT(self__storage_saved.value().is_alias_of(self_.storage()));
  if (self__impl_saved) AT_ASSERT(self__impl_saved == self_.getIntrusivePtr());
  #endif
  if (grad_fn) {
      set_history(flatten_tensor_args( result ), grad_fn);
  }
  if (tracer_state) {
    jit::tracer::setTracingState(std::move(tracer_state));
    jit::tracer::addOutput(node, result);
  }
  return result;
}

kuuでの選択

kuuにおいてパターンA とパターンBのどちらを選ぶかを考えた際、呼び出し方が自然な書き方だと思えたパターンBを選択すことにした。 実は最初はパターンAで実装したのだが、呼び出しを簡潔に書けない点がどうしても気になった。PyTorchのC++APIならkuuよりずっときれいに呼び出せていたのが、そもそもこの考察をしたきっかけである。

ゼロから作るDeep Learning ❸ ―フレームワーク編

ゼロから作るDeep Learning ❸ ―フレームワーク編

  • 作者:斎藤 康毅
  • 発売日: 2020/04/20
  • メディア: 単行本(ソフトカバー)