
[Performance] High CUDA memory usage with ONNX Runtime and inconsistent memory release #22297

Open
niyathimariya opened this issue Oct 3, 2024 · 2 comments
Labels
ep:CUDA (issues related to the CUDA execution provider), memory, performance (issues related to performance regressions)

Comments

@niyathimariya

niyathimariya commented Oct 3, 2024

After converting my PyTorch model to ONNX format, I noticed an issue with CUDA memory management. When processing a large input, the CUDA memory usage spikes as expected. However, for subsequent smaller inputs, the memory usage does not decrease, and the high CUDA memory allocation persists.

To mitigate this, I attempted to configure the ONNX Runtime session options as follows:

session_options.add_session_config_entry("memory.enable_memory_arena_shrinkage", "1")  
session_options.add_session_config_entry('cudnn_conv_use_max_workspace', '0')

While these settings increased the inference time, there was no improvement in CUDA memory usage, and the memory was not released as expected after processing smaller inputs.

Expected Behavior:
CUDA memory usage should decrease after processing smaller inputs, releasing the previously allocated memory.

Observed Behavior:
CUDA memory is not released when the model processes a shorter input sequence after a longer one, whereas the PyTorch model releases memory in such cases.
(attached screenshot: Picture1)

Is there any way to optimize CUDA memory usage in ONNX Runtime for this case?

To reproduce

import logging

import torch
import onnxruntime as ort
from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer

logger = logging.getLogger(__name__)

def load_model(self, model_name):
    session_options = ort.SessionOptions()
    session_options.add_session_config_entry('cudnn_conv_use_max_workspace', '0')
    session_options.add_session_config_entry("memory.enable_memory_arena_shrinkage", "1") 
    session_options.enable_mem_pattern = False
    session_options.arena_extend_strategy = "kSameAsRequested"
    session_options.gpu_mem_limit = 10 * 1024 * 1024 * 1024
    
    model = ORTModelForSeq2SeqLM.from_pretrained(model_name, provider="CUDAExecutionProvider", session_options=session_options)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return tokenizer, model

def inference(self, batch, doc_id='-1'):
    responses, status = '', False
    try:
        encodings = self.tokenizer(batch, padding=True, truncation=True, max_length=8192, return_tensors="pt").to(self.device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                encodings.input_ids,
                max_new_tokens=1024
            )
            responses = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            status = True  
    except Exception as e:
        logger.error(f"Failed to do inference on LLM, error: {e}")

    torch.cuda.empty_cache()
    return status, responses

Urgency

No response

Platform

Linux

OS Version

Ubuntu 22.04.4 LTS

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.19.2

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

cuda_12.1.r12.1

Model File

No response

Is this a quantized model?

No

@niyathimariya niyathimariya added the performance issues related to performance regressions label Oct 3, 2024
@github-actions github-actions bot added the ep:CUDA issues related to the CUDA execution provider label Oct 3, 2024
@niyathimariya (Author)

@pranavsharma @hariharans29, is this normal behavior of ONNX Runtime?

@hariharans29 (Member)

Hi @niyathimariya -

Arena shrinkage is a run option, not a session option as configured in the sample code (as an aside, I also think cudnn_conv_use_max_workspace is a CUDA provider option, not a session option). Please see this -

ro2.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu:0;gpu:0")
. So I am not sure how these settings increased the inference time, as configured they should not have had any effect.

Generally speaking, I recommend combining the arena shrinkage feature with some other settings. By default, model weights are allocated through the memory arena, so the arena may have grown substantially just by loading weights, and runtime shrinkage has no effect on allocations that are still "in use" (i.e. weight-related allocations). It is therefore best to allocate weights through non-arena memory, use the arena only for runtime allocations, configure a high enough initial chunk for the arena to service "most" requests, and let the shrinkage logic de-allocate the chunks created for the "large" requests. Hope this makes sense. Please see the detailed comment here - #9509 (comment).

I am not quite sure what arena configuration settings are currently available through the Python API. Please take a look at the relevant code - a good starting point is the test code here -

ort_arena_cfg = onnxrt.OrtArenaCfg(8, 0, 4, 2)
. Contributions are welcome if you would like to enhance the allocator config functionality via the Python API.
