Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate to new API for openai >= 1.0.0 #3

Merged
merged 4 commits into from
Nov 17, 2023

Conversation

grit-app[bot]
Copy link
Contributor

@grit-app grit-app bot commented Nov 17, 2023

✅ This migration is up to date! ✅

⚠️ This pull request was auto-generated with Grit. ⚠️

This pull request was created with these settings:

  • Target branch: main
  • Source files: **/*
  • Preset pattern: openai – Convert OpenAI from openai version to the v1 version.
Pattern body
engine marzano(0.1)
language python

pattern rename_resource() {
    or {
        `Audio` => `audio`,
        `ChatCompletion` => `chat.completions`,
        `Completion` => `completions`,
        `Edit` => `edits`,
        `Embedding` => `embeddings`,
        `File` => `files`,
        `FineTune` => `fine_tunes`,
        `FineTuningJob` => `fine_tuning`,
        `Image` => `images`,
        `Model` => `models`,
        `Moderation` => `moderations`,
    }
}

pattern rename_resource_cls() {
    or {
        r"Audio" => `resources.Audio`,
        r"ChatCompletion" => `resources.chat.Completions`,
        r"Completion" => `resources.Completions`,
        r"Edit" => `resources.Edits`,
        r"Embedding" => `resources.Embeddings`,
        r"File" => `resources.Files`,
        r"FineTune" => `resources.FineTunes`,
        r"FineTuningJob" => `resources.FineTuning`,
        r"Image" => `resources.Images`,
        r"Model" => `resources.Models`,
        r"Moderation" => `resources.Moderations`,
    }
}

pattern deprecated_resource() {
    or {
        `Customer`,
        `Deployment`,
        `Engine`,
        `ErrorObject`,
    }
}

pattern deprecated_resource_cls() {
    or {
        r"Customer",
        r"Deployment",
        r"Engine",
        r"ErrorObject",
    }
}


pattern rename_func($has_sync, $has_async, $res, $stmt, $params, $client) {
    $func where {
        if ($func <: r"a([a-zA-Z0-9]+)"($func_rest)) {
            $has_async = `true`,
            $func => $func_rest,
            if ($client <: undefined) {
                $stmt => `aclient.$res.$func($params)`,
            } else {
                $stmt => `$client.$res.$func($params)`,
            }
        } else {
            $has_sync = `true`,
            if ($client <: undefined) {
                $stmt => `client.$res.$func($params)`,
            } else {
                $stmt => `$client.$res.$func($params)`,
            }
        },
        // Fix function renames
        if ($res <: `Image`) {
          $func => `generate`
        }
    }
}

pattern change_import($has_sync, $has_async, $need_openai_import, $azure, $client_params) {
    $stmt where {
        $imports_and_defs = [],

        if ($need_openai_import <:  `true`) {
            $imports_and_defs += `import openai`,
        },

        if ($azure <: true) {
          $client = `AzureOpenAI`,
          $aclient = `AsyncAzureOpenAI`,
        } else {
          $client = `OpenAI`,
          $aclient = `AsyncOpenAI`,
        },

        $formatted_params = join(list = $client_params, separator = `,\n`),

        if (and { $has_sync <: `true`, $has_async <: `true` }) {
            $imports_and_defs += `from openai import $client, $aclient`,
            $imports_and_defs += ``, // Blank line
            $imports_and_defs += `client = $client($formatted_params)`,
            $imports_and_defs += `aclient = $aclient($formatted_params)`,
        } else if ($has_sync <: `true`) {
            $imports_and_defs += `from openai import $client`,
            $imports_and_defs += ``, // Blank line
            $imports_and_defs += `client = $client($formatted_params)`,
        } else if ($has_async <: `true`) {
            $imports_and_defs += `from openai import $aclient`,
            $imports_and_defs += ``, // Blank line
            $imports_and_defs += `aclient = $aclient($formatted_params)`,
        },

        $formatted = join(list = $imports_and_defs, separator=`\n`),
        $stmt => `$formatted`,
    }
}

pattern rewrite_whole_fn_call($import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure) {
    or {
        rename_resource() where {
            $import = `true`,
            $func <: rename_func($has_sync, $has_async, $res, $stmt, $params, $client),
            if ($azure <: true) {
              $params <: maybe contains bubble `engine` => `model`
            }
        },
        deprecated_resource() as $dep_res where {
            $stmt_whole = $stmt,
            if ($body <: contains `$_ = $stmt` as $line) {
                $stmt_whole = $line,
            },
            $stmt_whole => todo(message=`The resource '$dep_res' has been deprecated`, target=$stmt_whole),
        }
    }
}

pattern unittest_patch() {
    or {
        decorated_definition($decorators, definition=$_) where {
            $decorators <: contains bubble decorator(value=`patch($cls_path)`) as $stmt where {
                $cls_path <: contains r"openai\.([a-zA-Z0-9]+)(?:.[^,]+)?"($res),
                if ($res <: rename_resource_cls()) {} else {
                    $res <: deprecated_resource_cls(),
                    $stmt => todo(message=`The resource '$res' has been deprecated`, target=$stmt),
                }
            }
        },
        function_definition($body) where {
            $body <: contains bubble($body) or {
                `patch.object($params)`,
                `patch($params)`,
            } as $stmt where {
                $params <: contains bubble($body, $stmt) r"openai\.([a-zA-Z0-9]+)(?:.[^,]+)?"($res) where or {
                    $res <: rename_resource_cls(),
                    and {
                        $res <: deprecated_resource_cls(),
                        $line = $stmt,
                        if ($body <: contains or { `with $stmt:`, `with $stmt as $_:` } as $l) {
                            $line = $l,
                        },
                        $line => todo(message=`The resource '$res' has been deprecated`, target=$line),
                    }
                }
            },
        }
    }
}

pattern pytest_patch() {
    decorated_definition($decorators, $definition) where {
        $decorators <: contains decorator(value=`pytest.fixture`),
        $definition <: bubble function_definition($body, $parameters) where {
            $parameters <: [$monkeypatch, ...],
            $body <: contains bubble($monkeypatch) or {
                `$monkeypatch.setattr($params)` as $stmt where {
                    $params <: contains bubble($stmt) r"openai\.([a-zA-Z0-9]+)(?:.[^,]+)?"($res) where or {
                        $res <: rename_resource_cls(),
                        $stmt => todo(message=`The resource '$res' has been deprecated`, target=$stmt),
                    }
                },
                `monkeypatch.delattr($params)` as $stmt where {
                    $params <: contains bubble($stmt) r"openai\.([a-zA-Z0-9]+)(?:.[^,]+)?"($res) where or {
                        $res <: rename_resource_cls(),
                        $stmt => todo(message=`The resource '$res' has been deprecated`, target=$stmt),
                    }
                },
            }
        },
    },
}

pattern openai_main($client, $azure) {
    $body where {
        if ($client <: undefined) {
            $need_openai_import = `false`,
            $create_client = true,
        } else {
            $need_openai_import = `true`,
            $create_client = false,
        },
        if ($azure <: undefined) {
          $azure = false,
        },
        $has_openai_import = `false`,
        $has_partial_import = `false`,
        $has_sync = `false`,
        $has_async = `false`,

        $client_params = [],

        $body <: any {
          // Mark all the places where we they configure openai as something that requires manual intervention
          if ($client <: undefined) {
            contains bubble($need_openai_import, $azure, $client_params) `openai.$field = $val` as $setter where {
              $field <: or {
                `api_type` where {
                  $res = .,
                  if ($val <: or {`"azure"`, `"azure_ad"`}) {
                    $azure = true
                  },
                },
                `api_base` where {
                  $azure <: true,
                  $client_params += `azure_endpoint=$val`,
                  $res = .,
                },
                `api_key` where {
                  $res = .,
                  $client_params += `api_key=$val`,
                },
                `api_version` where {
                  $res = .,
                  // Only Azure has api_version
                  $azure = true,
                  $client_params += `api_version=$val`,
                },
                $_ where {
                  $res = todo(message=`The 'openai.$field' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI($field=$val)'`, target=$setter),
                  $need_openai_import = `true`,
                }
              }
            } => $res
          },
          // Remap errors
          contains `openai.error.$exp` => `openai.$exp` where {
            $need_openai_import = `true`,
          },
          contains `import openai` as $import_stmt where {
              $body <: contains bubble($has_sync, $has_async, $has_openai_import, $body, $client, $azure) `openai.$res.$func($params)` as $stmt where {
                  $res <: rewrite_whole_fn_call(import = $has_openai_import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure),
              },
          },
          contains `from openai import $resources` as $partial_import_stmt where {
            $has_partial_import = `true`,
            $body <: contains bubble($has_sync, $has_async, $resources, $client, $azure) `$res.$func($params)` as $stmt where {
                $resources <: contains $res,
                $res <: rewrite_whole_fn_call($import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure),
            }
          },
          contains unittest_patch(),
          contains pytest_patch(),
        },

        if ($create_client <: true) {
            if ($has_openai_import <: `true`) {
                $import_stmt <: change_import($has_sync, $has_async, $need_openai_import, $azure, $client_params),
                if ($has_partial_import <: `true`) {
                    $partial_import_stmt => .,
                },
            } else if ($has_partial_import <: `true`) {
                $partial_import_stmt <: change_import($has_sync, $has_async, $need_openai_import, $azure, $client_params),
            },
        },
    }
}

file($body) where {
  // No client means instantiate one per file
  $body <: openai_main()
}

Please feel free to provide feedback on this pull request. Any comments will be incorporated into future migrations.

@phelps-sg phelps-sg changed the title [bot] Run grit migration: Apply a GritQL pattern Migrate to new API for openai >= 1.0.0 Nov 17, 2023
@phelps-sg phelps-sg merged commit 28a4228 into main Nov 17, 2023
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant