89 lines
2.6 KiB
Python
89 lines
2.6 KiB
Python
import threading
|
|
import time
|
|
from collections import defaultdict
|
|
|
|
import torch
|
|
|
|
|
|
class MemUsageMonitor(threading.Thread):
    """Background thread that samples CUDA memory usage while monitoring is active.

    Call ``monitor()`` to start a sampling session and ``stop()`` to end it and
    obtain the collected figures. If the CUDA memory APIs fail when probed in
    ``__init__``, the monitor disables itself: ``run()`` returns immediately and
    ``read()`` returns ``data`` without querying torch.
    """

    # Class-level defaults; every instance overrides these in __init__.
    run_flag = None  # threading.Event, set while a sampling session is active
    device = None  # device passed to torch.cuda.memory_stats / reset_peak_memory_stats
    disabled = False  # True when the CUDA memory probe in __init__ raised
    opts = None  # settings object; must expose ``memmon_poll_rate`` (samples per second)
    data = None  # defaultdict(int) of recorded figures (byte counts from torch)

    def __init__(self, name, device, opts):
        threading.Thread.__init__(self)
        self.name = name
        self.device = device
        self.opts = opts

        # Daemon thread so it never blocks interpreter shutdown.
        self.daemon = True
        self.run_flag = threading.Event()
        self.data = defaultdict(int)

        # Probe the CUDA memory APIs once; if either call fails the monitor
        # becomes a no-op instead of raising later inside the thread.
        try:
            torch.cuda.mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def run(self):
        """Thread body: wait for ``monitor()``, then poll free memory until ``stop()``."""
        if self.disabled:
            return

        while True:
            self.run_flag.wait()

            # Reset peak counters on the same device that read() reports on.
            torch.cuda.reset_peak_memory_stats(self.device)
            self.data.clear()

            poll_rate = self.opts.memmon_poll_rate
            if poll_rate <= 0:
                # Polling disabled: end this session immediately.
                self.run_flag.clear()
                continue

            # Read the rate once per session so changing the setting to 0 while
            # polling cannot raise ZeroDivisionError and kill this thread.
            poll_interval = 1 / poll_rate

            self.data["min_free"] = torch.cuda.mem_get_info()[0]

            while self.run_flag.is_set():
                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(poll_interval)

    def dump_debug(self):
        """Print recorded data and torch's byte-valued memory stats in MiB (rounded up)."""
        print(self, 'recorded data:')
        for k, v in self.read().items():
            print(k, -(v // -(1024 ** 2)))  # ceiling division: bytes -> MiB

        print(self, 'raw torch memory stats:')
        tm = torch.cuda.memory_stats(self.device)
        for k, v in tm.items():
            if 'bytes' not in k:
                continue

            print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2)))

        print(torch.cuda.memory_summary())

    def monitor(self):
        """Start a sampling session (wakes the thread blocked in ``run()``)."""
        self.run_flag.set()

    def read(self):
        """Refresh and return ``self.data`` with current CUDA memory figures.

        Returns:
            The ``data`` defaultdict. When the monitor is enabled it contains
            ``free``, ``total``, ``active``, ``active_peak``, ``reserved``,
            ``reserved_peak`` and ``system_peak`` (all byte counts), plus
            ``min_free`` if a polling session recorded one. When disabled, the
            dict is returned unchanged.
        """
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            self.data["free"] = free
            self.data["total"] = total

            torch_stats = torch.cuda.memory_stats(self.device)
            # "active.all.current" is an allocation *count*; use the byte
            # counter so "active" shares the unit of "active_peak" and the rest.
            self.data["active"] = torch_stats["active_bytes.all.current"]
            self.data["active_peak"] = torch_stats["active_bytes.all.peak"]
            self.data["reserved"] = torch_stats["reserved_bytes.all.current"]
            self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"]

            # Lowest free memory observed: the polled minimum (if any) folded
            # with the current reading. Without a recorded sample, the
            # defaultdict would yield 0 and report the whole card as used.
            min_free = min(self.data.get("min_free", free), free)
            self.data["system_peak"] = total - min_free

        return self.data

    def stop(self):
        """End the sampling session and return the final ``read()`` snapshot."""
        self.run_flag.clear()
        return self.read()
|