[GH-ISSUE #18] [Feature Request] torch_geometric support

Closed
opened 2026-05-05 03:22:05 -06:00 by gitea-mirror · 5 comments
Owner

Originally created by @plutonium-239 on GitHub (May 27, 2022).
Original GitHub issue: https://github.com/XuehaiPan/nvitop/issues/18

Originally assigned to: @XuehaiPan on GitHub.

First of all, thank you for the excellent nvitop.

I want to know if you have plans to add an integration with [`PyTorch Geometric (pyg)`](https://www.pyg.org/)? It is a really great library for GNNs. I don't know if it's helpful at all, but it also has some profiling functions in the `torch_geometric.profile` module.
Since PyTorch Lightning doesn't give you granular control over your models (sometimes required in research), I haven't seen anyone use it. On the flip side, PyTorch Geometric is probably the most popular library for GNNs.

Hope you consider this!

gitea-mirror 2026-05-05 03:22:05 -06:00

@XuehaiPan commented on GitHub (May 30, 2022):

> I want to know if you have plans to add an integration with [`PyTorch Geometric (pyg)`](https://www.pyg.org/)? It is a really great library for GNNs. I don't know if it's helpful at all, but it also has some profiling functions in the `torch_geometric.profile` module.

Hi @plutonium-239, thanks for the suggestion. I'll look into it over the next two weeks.

> Since PyTorch Lightning doesn't give you granular control over your models (sometimes required in research)

`nvitop.core` provides low-level APIs that users can integrate into their training/testing code. `nvitop.callbacks` offers high-level APIs for specific frameworks.

I'm thinking of adding a more customizable mid-level API, e.g., a background daemon that gathers process status (on both the host and the GPUs) and lets users log the useful items to `tensorboard.SummaryWriter`, a CSV file, or `stdout`.
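The mid-level API idea above can be sketched as a daemon thread that periodically polls a metrics callable and keeps running min/mean/max aggregates. Everything here (the `MetricSampler` name, the callable interface) is a hypothetical illustration of the design, not nvitop's actual implementation:

```python
import threading
import time
from collections import defaultdict


class MetricSampler:
    """Hypothetical mid-level collector: a daemon thread that polls a
    metrics callable at a fixed interval and aggregates each key."""

    def __init__(self, sample_fn, interval=1.0):
        self.sample_fn = sample_fn  # callable returning Dict[str, float]
        self.interval = interval
        self._samples = defaultdict(list)
        self._stop = threading.Event()
        self._thread = threading.Thread(target=self._run, daemon=True)

    def _run(self):
        # Sample immediately, then once per interval until stopped.
        while not self._stop.is_set():
            for key, value in self.sample_fn().items():
                self._samples[key].append(value)
            self._stop.wait(self.interval)

    def __enter__(self):
        self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        self._thread.join()

    def collect(self):
        # Flatten into '<metric>/<mean|min|max>' keys.
        result = {}
        for key, values in self._samples.items():
            result[f'{key}/mean'] = sum(values) / len(values)
            result[f'{key}/min'] = min(values)
            result[f'{key}/max'] = max(values)
        return result


# Usage with a fake sampler standing in for real host/GPU queries:
with MetricSampler(lambda: {'cpu_percent (%)': 12.5}, interval=0.1) as sampler:
    time.sleep(0.3)
metrics = sampler.collect()
```

A real collector would plug host and GPU queries into `sample_fn`; the point of the sketch is only the daemon-plus-aggregate shape of the API.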


@XuehaiPan commented on GitHub (Jun 22, 2022):

Hi @plutonium-239, I have looked into the source code of [`PyTorch Geometric (pyg)`](https://www.pyg.org/). It looks like `pyg` supports PyTorch Lightning, so the existing callback in `nvitop` is also usable.

For flexibility, I have implemented a new metric collector in PR #21, which gives the user full control over the code logic of DL training.

For example:

```python
>>> import os
>>> os.environ['CUDA_VISIBLE_DEVICES'] = '3,2,1,0'

>>> from nvitop import ResourceMetricCollector, Device, CudaDevice

>>> collector = ResourceMetricCollector()                          # log all devices and children processes on the GPUs of the current process
>>> collector = ResourceMetricCollector(root_pids={1})             # log all devices and all GPU processes
>>> collector = ResourceMetricCollector(devices=CudaDevice.all())  # use the CUDA ordinal

>>> with collector(tag='<tag>'):
...     # do something
...     collector.collect()  # -> Dict[str, float]
# key -> '<tag>/<scope>/<metric (unit)>/<mean/min/max>'
{
    '<tag>/host/cpu_percent (%)/mean': 8.967849777683456,
    '<tag>/host/cpu_percent (%)/min': 6.1,
    '<tag>/host/cpu_percent (%)/max': 28.1,
    ...,
    '<tag>/host/memory_percent (%)/mean': 21.5,
    '<tag>/host/swap_percent (%)/mean': 0.3,
    '<tag>/host/memory_used (GiB)/mean': 91.0136418208109,
    '<tag>/host/load_average (%) (1 min)/mean': 10.251427386878328,
    '<tag>/host/load_average (%) (5 min)/mean': 10.072539414569503,
    '<tag>/host/load_average (%) (15 min)/mean': 11.91126970422139,
    ...,
    '<tag>/cuda:0 (gpu:3)/memory_used (MiB)/mean': 3.875,
    '<tag>/cuda:0 (gpu:3)/memory_free (MiB)/mean': 11015.562499999998,
    '<tag>/cuda:0 (gpu:3)/memory_total (MiB)/mean': 11019.437500000002,
    '<tag>/cuda:0 (gpu:3)/memory_percent (%)/mean': 0.0,
    '<tag>/cuda:0 (gpu:3)/gpu_utilization (%)/mean': 0.0,
    '<tag>/cuda:0 (gpu:3)/memory_utilization (%)/mean': 0.0,
    '<tag>/cuda:0 (gpu:3)/fan_speed (%)/mean': 22.0,
    '<tag>/cuda:0 (gpu:3)/temperature (C)/mean': 25.0,
    '<tag>/cuda:0 (gpu:3)/power_usage (W)/mean': 19.11166264116916,
    ...,
    '<tag>/cuda:1 (gpu:2)/memory_used (MiB)/mean': 8878.875,
    ...,
    '<tag>/cuda:2 (gpu:1)/memory_used (MiB)/mean': 8182.875,
    ...,
    '<tag>/cuda:3 (gpu:0)/memory_used (MiB)/mean': 9286.875,
    ...,
    '<tag>/pid:12345/host/cpu_percent (%)/mean': 151.34342772112265,
    '<tag>/pid:12345/host/host_memory (MiB)/mean': 44749.72373447514,
    '<tag>/pid:12345/host/host_memory_percent (%)/mean': 8.675082352111717,
    '<tag>/pid:12345/host/running_time (min)': 336.23803206741576,
    '<tag>/pid:12345/cuda:1 (gpu:4)/gpu_memory (MiB)/mean': 8861.0,
    '<tag>/pid:12345/cuda:1 (gpu:4)/gpu_memory_percent (%)/mean': 80.4,
    '<tag>/pid:12345/cuda:1 (gpu:4)/gpu_memory_utilization (%)/mean': 6.711118172407917,
    '<tag>/pid:12345/cuda:1 (gpu:4)/gpu_sm_utilization (%)/mean': 48.23283397736476,
    ...,
    '<tag>/duration (s)': 7.247399162035435,
    '<tag>/timestamp': 1655909466.9981883
}
```

The results can easily be logged to [TensorBoard](https://github.com/tensorflow/tensorboard) or to a CSV file.

```python
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from nvitop import CudaDevice, ResourceMetricCollector
from nvitop.callbacks.tensorboard import add_scalar_dict

# Build networks and prepare datasets
...

# Logger and status collector
writer = SummaryWriter()
collector = ResourceMetricCollector(devices=CudaDevice.all(),  # log all visible CUDA devices and use the CUDA ordinal
                                    root_pids={os.getpid()},   # only log the children processes of the current process
                                    interval=1.0)              # snapshot interval for background daemon thread

# Start training
global_step = 0
for epoch in range(num_epoch):
    with collector(tag='train'):
        for batch in train_dataset:
            with collector(tag='batch'):
                metrics = train(net, batch)
                global_step += 1
                add_scalar_dict(writer, 'train', metrics, global_step=global_step)
                add_scalar_dict(writer, 'resources',      # tag='resources/train/batch/...'
                                collector.collect(),
                                global_step=global_step)

        add_scalar_dict(writer, 'resources',              # tag='resources/train/...'
                        collector.collect(),
                        global_step=epoch)

    with collector(tag='validate'):
        metrics = validate(net, validation_dataset)
        add_scalar_dict(writer, 'validate', metrics, global_step=epoch)
        add_scalar_dict(writer, 'resources',              # tag='resources/validate/...'
                        collector.collect(),
                        global_step=epoch)
```
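For the CSV path, the flat `Dict[str, float]` returned by `collector.collect()` can be appended with the standard library's `csv.DictWriter`. This is a minimal sketch assuming the key/value shape shown above; `append_metrics_row` is a hypothetical helper, not part of nvitop:

```python
import csv

def append_metrics_row(path, metrics):
    """Append one collect() result as a CSV row, writing a header
    row the first time the file is created (hypothetical helper)."""
    try:
        # 'x' mode fails if the file already exists, so the header
        # is only written once.
        with open(path, 'x', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=sorted(metrics))
            writer.writeheader()
            writer.writerow(metrics)
    except FileExistsError:
        with open(path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=sorted(metrics))
            writer.writerow(metrics)

# Example with a hand-made dict standing in for collector.collect():
append_metrics_row('resources.csv', {
    'train/duration (s)': 7.25,
    'train/host/cpu_percent (%)/mean': 8.97,
})
```

Note that this assumes the set of metric keys is stable across calls; if keys can vary between snapshots, the fieldnames would need to be fixed up front.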

@XuehaiPan commented on GitHub (Jun 22, 2022):

The new feature is built on top of the `mig-support` branch, which has not been released yet. To install:

```bash
pip3 install git+https://github.com/XuehaiPan/nvitop.git@collector#egg=nvitop
```

Any feedback is welcome.


@plutonium-239 commented on GitHub (Jun 24, 2022):

This is awesome!
Thanks so much!


@XuehaiPan commented on GitHub (Jun 26, 2022):

Close as resolved by PR #21.
