본문 바로가기
Artificial Intelligence/Neural Networks

[ CVPR2022 / GML4VC ] 5. Pytorch Geometric 이란 무엇인가?

by SuperMemi 2022. 8. 24.
반응형

Graph Machine Learning for Visual Computing (GML4VC) Tutorial 

 


CVPR 2022 에서 Graph Machine Learning 에 대한 튜토리얼(tutorial)을 진행했습니다.
이에 대해 요약 정리 하는 시리즈 글입니다.


[이전 글]

2022.08.22 - [AI/Graph Neural Networks] - [ CVPR2022 / GML4VC ] 1. 개요 (Graph Machine Learning, GNNs)

2022.08.22 - [AI/Graph Neural Networks] - [ CVPR2022 / GML4VC ] 2. Open Remarks

2022.08.22 - [AI/Graph Neural Networks] - [ CVPR2022 / GML4VC ] 3. Geometric Deep Learning (Invariant, Equivariant)

2022.08.24 - [AI/Graph Neural Networks] - [ CVPR2022 / GML4VC ] 4. Graph Neural Networks(GNNS) 기본 개념 정리

Graph Neural Networks(GNNS) 기본 개념 정리에 이어지는 내용입니다.


Building GNNs with Pytorch Geometric

 

matthias_talk.mp4

 

drive.google.com

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric


PyTorch Geometric

  • PyG(PyTorch Geometric) : 그래프포인트 클라우드 그리고 매니폴드에 대한 딥러닝을 가능하게 하는 PyTorch 라이브러리입니다.
  • Graph Neural Networks 들이 구현되어 있으며 쉽게 사용합니다.
  • GPU 의 고속 활용이 가능합니다. 
  • 유연성과 활용성으로 쉽게 산업이나 학업에 적용가능합니다.

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric


Design Principles

기본적으로 PyG 는 PyTorch 와 거의 유사한 구조를 가지고 있습니다. 

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric


Message Passing Graph Neural Networks

PyG 에서의 그래프는 sparse graph 로 정의합니다. Sparse graph 란 연결된 edge(Non-zero)만 정보로 저장하는 Graph 형식입니다. 

 

  • graph G = ( input node feature matrix H, (edge indices I, edge features E 
  • input node feature matrix : 각 노드의 특징 정보를 행렬로 나타낸 것 입니다. Initial node representation 으로 이해할 수 있습니다.
  • edge indicies : 방향성을 가지며, start node 와 end node의 index 로 구성됩니다. 
  • edge features E : 엣지의 특징을 나타내는 행렬로서, 선택적으로 존재합니다. 즉, 없을 수도 있습니다.  

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

 

PyG 에서는 gather scatter 연산자의 병렬화로 구성된 유연한 구현을 통해 MessagePassing 이 동작합니다. 

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

 

PyG Code with MessagePassing

  1. MessagePassing 을 상속 받습니다.
  2. 초기화 하면서 사전에 정의된 aggregation scheme 을 정의합니다.
  3. self.propagate 를 호출하면서 message passing 이 진행됩니다. 그리고 propagate 함수 내부적으로 message 함수가 호출됩니다.
  4. *_j*_i 를 변수명에 붙임으로 써, Node-level 표현들이 자동적으로 edge-level 표현으로 전환됩니다.

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

class EdgeConv(MessagePassing):

    def __init__(self, in_dim, out_dim):
        super().__init__(aggr="max")
        self.mlp = MLP(U2 * in_dim, out_dim)
        
    def forward(self, x: Tensor, edge_index: Tensor):
        return self.propagate(edge_index, x=x)
    
    def message(self, x_j: Tensor, x_i = Tensor):
        edge_features = torch.cat([x_i, x_j - x_i], dim=1)
        return self.mlp(edge_features)

 

Aggregations Formats

Aggregation Format 에는 다양한 종류가 있습니다. 각각의 장단점이 존재하여 경우에 맞게 잘 사용하는 것이 중요합니다.

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

 

Aggregation Algorithms

직접 구현도 가능하고, nn.aggr 에서 구현된 것을 사용할 수 있습니다.

각 방식들은 위의 format 중에 가장 효율적인 것으로 선택됩니다.

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric


Mini-Batching

 

  • PyG 는 작은 그래프들로 이루어진 mini batch 를 가능하게 합니다.
  • Graph 가 sparse form 이기 때문에 메모리 걱정이 덜합니다.
  • 다양한 사이즈의 그래프들을 처리가능합니다.

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

  • 다양한 사이즈에 대응하기 위해서
  • PyG 에서는 batch dimension 과 node dimension 을 하나의 single dimension 으로 병합합니다.
  • 다른 방법으로는 nested_tensor 방식이 존재합니다.

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric


Example

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

  • GeometricShapes Dataset : 40 개
    • pos : 3차원 좌표 32개
    • face : mesh의 삼각형 표현 (각 삼각형의 꼭짓점 좌표를 pos의 index 값으로 표현)
    • y : category class
  • Point Cloud Sampling 또는 Dynamic Graph Generation 으로 표현 가능
import torch_geometric.transforms as T
from torch_geometric.datasets import GeometricShapes

dataset = GeometricShapes(root='data/GeometricShapes')
data = dataset[0]

print(data)
print(data.pos)
print(data.face)
print(data.y)
print(data.face.min()) # 0
print(data.face.max()) # 31

# Point Cloud Sampling:
dataset.transform = T.SamplePoints(num=256)
print(dataset[0])

# Dynamic Graph Generation
dataset.transform = T.Compose([T.SamplePoints(num=256), T.KNNGraph(k=6)])
print(dataset[0])

 

  • Network 학습 과정

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric


Additional Highlights

 

from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric
from CVPR 2022 Tutorial:Building GNNs with Pytorch Geometric

 


[ 다음 글 ]

2022.08.25 - [AI/Graph Neural Networks] - [ CVPR2022 / GML4VC ] 6. Deep GNNs (심층 그래프 신경망) 기본 개념 정리

2022.08.25 - [AI/Graph Neural Networks] - [ CVPR2022 / GML4VC ] 7. 비디오 이해를 위한 GNN 응용 (Graph ML for Video Understanding)


 

반응형