diff --git a/metals/src/main/scala/scala/meta/internal/metals/codeactions/ExtractRenameMember.scala b/metals/src/main/scala/scala/meta/internal/metals/codeactions/ExtractRenameMember.scala index 171a5b23dc6..885e79edb07 100644 --- a/metals/src/main/scala/scala/meta/internal/metals/codeactions/ExtractRenameMember.scala +++ b/metals/src/main/scala/scala/meta/internal/metals/codeactions/ExtractRenameMember.scala @@ -14,6 +14,7 @@ import scala.meta.Template import scala.meta.Term import scala.meta.Tree import scala.meta.Type +import scala.meta.inputs.Position import scala.meta.internal.metals.Buffers import scala.meta.internal.metals.ClientCommands import scala.meta.internal.metals.MetalsEnrichments._ @@ -114,9 +115,7 @@ class ExtractRenameMember( case _ => val codeActionOpt = for { defn <- defnAtCursor - if canExtractDefn( - defn.member - ) + if canExtractDefn(defn.member) memberType <- getMemberType(defn.member) title = ExtractRenameMember.title( memberType, @@ -143,12 +142,19 @@ class ExtractRenameMember( case t: Defn.Trait => nodes += EndableMember(t, None) case o: Defn.Object => nodes += EndableMember(o, None) case e: Defn.Enum => nodes += EndableMember(e, None) + case t: Defn.Type if t.mods.exists { + case Mod.Opaque() => true + case _ => false + } => + nodes += EndableMember(t, None) case endMarker: Term.EndMarker => - if (nodes.size > 0) { - val last = nodes.remove(nodes.size - 1) - nodes += EndableMember(last.member, Some(endMarker)) + nodes.lastOption match { + case Some(last) + if last.maybeEndMarker.isEmpty && endMarker.name.value == last.member.name.value => + nodes.remove(nodes.size - 1) + nodes += EndableMember(last.member, Some(endMarker)) + case _ => } - case s: Source => super.apply(s) case _ => @@ -159,21 +165,59 @@ class ExtractRenameMember( nodes.toList } - case class Comments(text: String, startPos: l.Position) + private def findExtensions( + tree: Tree + ): List[EndableDefn[Defn.ExtensionGroup]] = { + val nodes: ListBuffer[EndableDefn[Defn.ExtensionGroup]] = ListBuffer() - case class EndableMember( - member: Member, + val traverser = new SimpleTraverser { + override def apply(tree: Tree): Unit = tree match { + case p: Pkg => super.apply(p) + case s: Source => super.apply(s) + case e: Defn.ExtensionGroup => + nodes += EndableDefn[Defn.ExtensionGroup](e, None) + case endMarker: Term.EndMarker if endMarker.name.value == "extension" => + nodes.lastOption match { + case Some(last) if last.maybeEndMarker.isEmpty => + nodes.remove(nodes.size - 1) + nodes += EndableDefn[Defn.ExtensionGroup]( + last.member, + Some(endMarker), + ) + case _ => + } + case _ => + } + } + traverser(tree) + + nodes.toList + } + + case class Comments(text: String, startPos: Int) + + type EndableMember = EndableDefn[Member] + object EndableMember { + def apply( + member: Member, + maybeEndMarker: Option[Term.EndMarker], + ): EndableMember = + EndableDefn(member, maybeEndMarker) + } + case class EndableDefn[T <: Tree]( + member: T, maybeEndMarker: Option[Term.EndMarker], commentsAbove: Option[Comments] = None, ) { - def withComments(comments: Comments): EndableMember = + def withComments(comments: Comments): EndableDefn[T] = this.copy(commentsAbove = Some(comments)) - def memberPos: l.Range = { - val pos = member.pos.toLsp - commentsAbove.foreach(comments => pos.setStart(comments.startPos)) - pos - } - def endMarkerPos: Option[l.Range] = maybeEndMarker.map(_.pos.toLsp) + def memberPos: Position = + commentsAbove match { + case None => member.pos + case Some(comment) => + Position.Range(member.pos.input, comment.startPos, member.pos.end) + } + def endMarkerPos: Option[Position] = maybeEndMarker.map(_.pos) } private def isSealed(t: Tree): Boolean = t match { @@ -191,7 +235,6 @@ class ExtractRenameMember( case o: Defn.Object => o.name.value :: completePreName(o) case po: Pkg.Object => po.name.value :: completePreName(po) case tmpl: Template => completePreName(tmpl) - case _: Source => Nil case _ => Nil } case None => Nil @@ -229,6 +272,7 @@ class ExtractRenameMember( range: l.Range, endableMember: EndableMember, maybeCompanionEndableMember: Option[EndableMember], + extensions: List[EndableDefn[Defn.ExtensionGroup]], ): (String, Int) = { // List of sequential packages or imports before the member definition val packages: ListBuffer[Pkg] = ListBuffer() @@ -271,23 +315,27 @@ class ExtractRenameMember( val pkg: Option[Pkg] = mergedTermsOpt.map(t => Pkg(ref = t, stats = Nil)) - def marker(endableMember: EndableMember) = endableMember.maybeEndMarker - .map(endMarker => "\n" + endMarker.toString()) - .getOrElse("") - - val structure = pkg.toList.mkString("\n") :: - imports.mkString("\n") :: - endableMember.commentsAbove.map(_.text).getOrElse("") + - endableMember.member.toString + marker(endableMember) :: - maybeCompanionEndableMember - .flatMap(_.commentsAbove) - .map(_.text) - .getOrElse("") + - maybeCompanionEndableMember - .map(_.member.toString) - .getOrElse("") + maybeCompanionEndableMember - .map(marker) - .getOrElse("") :: Nil + def memberParts[T <: Tree](member: EndableDefn[T]) = { + val endMarker = + member.maybeEndMarker + .map(endMarker => "\n" + endMarker.toString()) + .getOrElse("") + member.commentsAbove.map(_.text).getOrElse("") + + member.member.toString + endMarker + } + + val definitionsParts = maybeCompanionEndableMember match { + case None => List(memberParts(endableMember)) + case Some(companion) + if (companion.memberPos.start < endableMember.memberPos.start) => + List(memberParts(companion), memberParts(endableMember)) + case Some(companion) => + List(memberParts(endableMember), memberParts(companion)) + } + + val structure = + pkg.toList.mkString("\n") :: imports.mkString("\n") :: + definitionsParts ++ extensions.map(memberParts) val preDefinitionLines = pkg.toList.length + imports.length val defnLine = @@ -322,6 +370,7 @@ class ExtractRenameMember( case t: Defn.Trait => namesFromTemplate(t.templ) case o: Defn.Object => namesFromTemplate(o.templ) case e: Defn.Enum => namesFromTemplate(e.templ) + case _ => Nil } } @@ -419,6 +468,7 @@ class ExtractRenameMember( val opt = for { tree <- trees.get(path) + text <- buffers.get(path) definitions = membersDefinitions(tree) memberDefn <- definitions .find( @@ -428,11 +478,23 @@ class ExtractRenameMember( companion = definitions .find(isCompanion(memberDefn.member)) .map(withComment) + extensions = findExtensions(tree).filter { e => + e.member.paramClauses.toList match { + case Term.ParamClause(param :: Nil, _) :: Nil => + param.decltpe match { + case Some(Type.Name(name)) => + name == memberDefn.member.name.value + case _ => false + } + case _ => false + } + } (fileContent, defnLine) = newFileContent( tree, range, memberDefn, companion, + extensions, ) newFilePath = newPathFromClass(uri, memberDefn.member) if !newFilePath.exists @@ -444,6 +506,8 @@ class ExtractRenameMember( fileContent, memberDefn, companion, + extensions, + text, ) val newFileMemberRange = new l.Range() val pos = new l.Position(defnLine, 0) @@ -471,18 +535,26 @@ class ExtractRenameMember( override def kind: String = l.CodeActionKind.RefactorExtract + private def whiteChars = Set('\r', '\n', ' ', '\t') + private def extractClassCommand( newUri: String, content: String, endableMember: EndableMember, maybeEndableMemberCompanion: Option[EndableMember], + extensions: List[EndableDefn[Defn.ExtensionGroup]], + fileText: String, ): List[l.TextEdit] = { val newPath = newUri.toAbsolutePath newPath.writeText(content) - def removeEdits(range: l.Range): List[l.TextEdit] = - List(new l.TextEdit(range, "")) + def removeEdit(pos: Position): l.TextEdit = new l.TextEdit(pos.toLsp, "") + + def removesPositionsForMember[T <: Tree]( + member: EndableDefn[T] + ): List[Position] = + member.memberPos :: member.endMarkerPos.toList val packageEdit = endableMember.member.parent .flatMap { @@ -495,17 +567,47 @@ class ExtractRenameMember( Some(p) case _ => None } - .map(tree => removeEdits(tree.pos.toLsp)) - - packageEdit.getOrElse( - removeEdits(endableMember.memberPos) ++ - (maybeEndableMemberCompanion - .map(_.memberPos) - ++ endableMember.endMarkerPos - ++ maybeEndableMemberCompanion - .flatMap(_.endMarkerPos)).flatMap(removeEdits) - ) + .map(tree => List(removeEdit(tree.pos))) + + // if there are only white chars between remove edits, we merge them + def mergeEdits( + edits: List[Position], + acc: List[Position], + ): List[Position] = { + edits match { + case edit1 :: edit2 :: rest => + def onlyWhiteCharsBetween = + fileText.slice(edit1.end, edit2.start).forall(whiteChars) + if (edit1.end >= edit2.start || onlyWhiteCharsBetween) { + val merged = Position.Range(edit1.input, edit1.start, edit2.end) + mergeEdits(merged :: rest, acc) + } else { + mergeEdits(edit2 :: rest, edit1 :: acc) + } + case edit :: Nil => + val followingWhites = + fileText.splitAt(edit.end)._2.takeWhile(whiteChars).size + Position.Range( + edit.input, + edit.start, + edit.end + followingWhites, + ) :: acc + case Nil => acc + } + } + + def membersRemove: List[l.TextEdit] = { + val positions = + removesPositionsForMember(endableMember) ++ + maybeEndableMemberCompanion.toList.flatMap( + removesPositionsForMember + ) ++ + extensions + .flatMap(removesPositionsForMember) + mergeEdits(positions.sortBy(_.start), Nil).map(removeEdit) + } + packageEdit.getOrElse(membersRemove) } private def findCommentsAbove(path: AbsolutePath, member: Member) = { @@ -524,7 +626,7 @@ class ExtractRenameMember( .map(_.pos) } yield Comments( part.splitAt(commentPos.start)._2, - commentPos.toLsp.getStart(), + commentPos.start, ) } @@ -538,6 +640,7 @@ object ExtractRenameMember { case _: Defn.Enum => "enum" case _: Defn.Trait => "trait" case _: Defn.Object => "object" + case _: Defn.Type => "opaque type" } def title(memberType: String, name: String): String = diff --git a/tests/slow/src/test/scala/tests/feature/Scala3CodeActionLspSuite.scala b/tests/slow/src/test/scala/tests/feature/Scala3CodeActionLspSuite.scala index bf8c63c727a..81a60e10e2c 100644 --- a/tests/slow/src/test/scala/tests/feature/Scala3CodeActionLspSuite.scala +++ b/tests/slow/src/test/scala/tests/feature/Scala3CodeActionLspSuite.scala @@ -600,6 +600,93 @@ class Scala3CodeActionLspSuite |""".stripMargin, ) + checkExtractedMember( + "opaque-type", + """|final class Bar() + | + |opaque type <> = String + | + |object Foo: + | val x = 1 + |""".stripMargin, + expectedActions = ExtractRenameMember.title("opaque type", "Foo"), + """| + |final class Bar() + | + |""".stripMargin, + newFile = ( + "Foo.scala", + s"""opaque type Foo = String + | + |object Foo: + | val x = 1 + |""".stripMargin, + ), + ) + + checkExtractedMember( + "opaque-type-companion", + """|final class Bar() + | + |opaque type Foo = String + | + |object <>: + | val x = 1 + |""".stripMargin, + expectedActions = ExtractRenameMember.title("object", "Foo"), + """| + |final class Bar() + | + |""".stripMargin, + newFile = ( + "Foo.scala", + s"""opaque type Foo = String + | + |object Foo: + | val x = 1 + |""".stripMargin, + ), + ) + + checkExtractedMember( + "opaque-type-extensions", + """|final class Bar() + | + |opaque type <> = String + | + |object Foo: + | val x = 1 + | + |extension (f: Foo) + | def u = 3 + |end extension + | + |extension (i: Int) + | def r = 1 + |end extension + |""".stripMargin, + expectedActions = ExtractRenameMember.title("opaque type", "Foo"), + """| + |final class Bar() + | + |extension (i: Int) + | def r = 1 + |end extension + |""".stripMargin, + newFile = ( + "Foo.scala", + s"""opaque type Foo = String + | + |object Foo: + | val x = 1 + | + |extension (f: Foo) + | def u = 3 + |end extension + |""".stripMargin, + ), + ) + private def getPath(name: String) = s"a/src/main/scala/a/$name" def checkExtractedMember(