diff --git a/xstream/src/java/com/thoughtworks/xstream/converters/extended/RecordConverter.java b/xstream/src/java/com/thoughtworks/xstream/converters/extended/RecordConverter.java index 71d085625..8431fd47e 100644 --- a/xstream/src/java/com/thoughtworks/xstream/converters/extended/RecordConverter.java +++ b/xstream/src/java/com/thoughtworks/xstream/converters/extended/RecordConverter.java @@ -153,7 +153,7 @@ private void writeItem(RecordComponent recordComponent, Object compValue, Marsha */ @Override public Object unmarshal(HierarchicalStreamReader reader, UnmarshallingContext context) { - final Class aRecord = findClass(reader); + final Class aRecord = context.getRequiredType(); if (!isRecord(aRecord)) { throw new ConversionException(aRecord + " is not a record"); } @@ -225,12 +225,6 @@ private static Class classForName(String className) { } } - private static Class findClass(HierarchicalStreamReader reader) { - String c = reader.getAttribute("class"); - String className = c != null ? c : reader.getNodeName(); - return classForName(className); - } - /** * Invokes the canonical constructor of a record class with the given argument values. */ diff --git a/xstream/src/test/com/thoughtworks/xstream/converters/extended/RecordConverterTest.java b/xstream/src/test/com/thoughtworks/xstream/converters/extended/RecordConverterTest.java index a057dbbab..81ce18bcb 100644 --- a/xstream/src/test/com/thoughtworks/xstream/converters/extended/RecordConverterTest.java +++ b/xstream/src/test/com/thoughtworks/xstream/converters/extended/RecordConverterTest.java @@ -14,8 +14,17 @@ import com.thoughtworks.acceptance.AbstractAcceptanceTest; import com.thoughtworks.xstream.XStream; import com.thoughtworks.xstream.converters.ConversionException; - +import com.thoughtworks.xstream.io.xml.StaxDriver; +import com.thoughtworks.xstream.security.AnyTypePermission; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Serializable; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.Objects; import java.util.stream.IntStream; +import junit.framework.Assert; /** @@ -296,4 +305,72 @@ static T expectThrows(Class throwableClass, Runnable ta return throwableClass.cast(cause); } } + + public static final class MyObj + { + SerializableRecord field1 = new SerializableRecord( 1, 2 ); + Object field2 = new SerializableRecord( 1, 2 ); + Object field3 = new Object[] { new SerializableRecord( 1, 2 ) }; + + NotSerializableRecord field4 = new NotSerializableRecord( 1, 2 ); + Object field5 = new NotSerializableRecord( 1, 2 ); + Object field6 = new Object[] { new NotSerializableRecord( 1, 2 ) }; + + @Override + public boolean equals(Object o) + { + if ( this == o ) + { + return true; + } + if ( o == null || getClass() != o.getClass() ) + { + return false; + } + final MyObj myObj = (MyObj) o; + return Objects.equals( field1, myObj.field1 ) && + Objects.equals( field2, myObj.field2 ) && + Arrays.equals( (Object[]) field3, (Object[]) myObj.field3 ) && + Objects.equals( field4, myObj.field4 ) && + Objects.equals( field5, myObj.field5 ) && + Arrays.equals( (Object[]) field6, (Object[]) myObj.field6 ); + } + + @Override + public int hashCode() + { + return Objects.hash( field1, field2, field3, field4, field5, field6 ); + } + } + + public record NotSerializableRecord(int a, int b) {} + public record SerializableRecord(int a, int b) implements Serializable {} + + public void testMissingClassAttributeDoesNotCauseCrash() + { + final MyObj in = new MyObj(); + + final XStream xstream = new XStream( new StaxDriver() ); + xstream.addPermission( AnyTypePermission.ANY); + + final ByteArrayOutputStream bos = new ByteArrayOutputStream(); + xstream.toXML( in, bos ); + final String expected = """ + +12 +12 +12 +1 +2 +12 +12 + + + """; + final String actual = bos.toString( StandardCharsets.UTF_8 ); + Assert.assertEquals( expected.replace("\n", ""), actual ); + + final Object out = xstream.fromXML( new ByteArrayInputStream( bos.toByteArray() ) ); + Assert.assertEquals( in, out ); + } }