diff --git a/source/ast/test/transformTest.ml b/source/ast/test/transformTest.ml index a2119689311..bfbceb90672 100644 --- a/source/ast/test/transformTest.ml +++ b/source/ast/test/transformTest.ml @@ -153,6 +153,36 @@ let test_transform _ = orelse = [+Statement.Expression (+Expression.Constant (Constant.Integer 5))]; }; ] + 0; + assert_modifying_source + [ + +Statement.Match + { + Match.subject = +Expression.Constant (Constant.Integer 0); + cases = + [ + { + Match.Case.pattern = +Match.Pattern.MatchSingleton (Constant.Integer 2); + guard = Some (+Expression.Constant (Constant.Integer 4)); + body = []; + }; + ]; + }; + ] + [ + +Statement.Match + { + Match.subject = +Expression.Constant (Constant.Integer 1); + cases = + [ + { + Match.Case.pattern = +Match.Pattern.MatchSingleton (Constant.Integer 3); + guard = Some (+Expression.Constant (Constant.Integer 5)); + body = []; + }; + ]; + }; + ] 0 @@ -552,6 +582,9 @@ let test_statement_transformer _ = y = 6 class C: z = 7 + match x: + case 1: + w = 0 |} {| def foo(): @@ -568,6 +601,9 @@ let test_statement_transformer _ = y = 7 class C: z = 8 + match x: + case 1: + w = 1 |} 28 diff --git a/source/ast/transform.ml b/source/ast/transform.ml index 77da8e9f669..c37ad27160e 100644 --- a/source/ast/transform.ml +++ b/source/ast/transform.ml @@ -299,8 +299,55 @@ module Make (Transformer : Transformer) = struct orelse = transform_list orelse ~f:transform_statement |> List.concat; } | Import _ -> value - (* TODO(T107108689): Transform for match statement. *) - | Match _ -> value + | Match { Match.subject; cases } -> + let rec transform_pattern { Node.value; location } = + let value = + match value with + | Match.Pattern.MatchAs { pattern; name } -> + Match.Pattern.MatchAs { pattern = pattern >>| transform_pattern; name } + | MatchClass { class_name; patterns; keyword_attributes; keyword_patterns } -> + MatchClass + { + class_name; + patterns = transform_list patterns ~f:transform_pattern; + keyword_attributes; + keyword_patterns = transform_list keyword_patterns ~f:transform_pattern; + } + | MatchMapping { keys; patterns; rest } -> + MatchMapping + { + keys = transform_list keys ~f:transform_expression; + patterns = transform_list patterns ~f:transform_pattern; + rest; + } + | MatchOr patterns -> MatchOr (transform_list patterns ~f:transform_pattern) + | MatchSequence patterns -> + MatchSequence (transform_list patterns ~f:transform_pattern) + | MatchSingleton constant -> ( + let expression = + transform_expression { Node.value = Expression.Constant constant; location } + in + match expression.value with + | Expression.Constant constant -> MatchSingleton constant + | _ -> MatchValue expression) + | MatchStar maybe_identifier -> MatchStar maybe_identifier + | MatchValue expression -> MatchValue expression + | MatchWildcard -> MatchWildcard + in + { Node.value; location } + in + let transform_case { Match.Case.pattern; guard; body } = + { + Match.Case.pattern = transform_pattern pattern; + guard = guard >>| transform_expression; + body = transform_list body ~f:transform_statement |> List.concat; + } + in + Match + { + Match.subject = transform_expression subject; + cases = transform_list cases ~f:transform_case; + } | Nonlocal _ -> value | Pass -> value | Raise { Raise.expression; from } -> @@ -385,8 +432,6 @@ module MakeStatementTransformer (Transformer : StatementTransformer) = struct | Expression _ | Global _ | Import _ - (* TODO(T107108689): Transform for match statement. *) - | Match _ | Pass | Raise _ | Return _ @@ -406,6 +451,11 @@ module MakeStatementTransformer (Transformer : StatementTransformer) = struct let body = List.concat_map ~f:transform_statement body in let orelse = List.concat_map ~f:transform_statement orelse in If { value with If.body; orelse } + | Match ({ Match.cases; _ } as value) -> + let transform_case ({ Match.Case.body; _ } as value) = + { value with Match.Case.body = List.concat_map ~f:transform_statement body } + in + Match { value with Match.cases = List.map ~f:transform_case cases } | While ({ While.body; orelse; _ } as value) -> let body = List.concat_map ~f:transform_statement body in let orelse = List.concat_map ~f:transform_statement orelse in