From 9cbc3f2b544cbc5b51ce423b65c93ff39316f779 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Proch=C3=A1zka?= Date: Wed, 20 May 2026 12:01:47 +0200 Subject: [PATCH 1/2] refactor(arrow/avro): migrate from hamba/avro to twmb/avro MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Switch the avro reader from hamba/avro to twmb/avro for schema parsing and value decoding. The new library hands back typed Go values (e.g. *big.Rat for decimals, [16]byte for fixed-uuid) instead of []byte, which shortens the data-side helpers and lets us fix a stack of bugs the old []byte arms were masking. The exported arrow/avro surface is preserved: NewOCFReader, OCFReader and its methods, the Option/With… helpers, and ArrowSchemaFromAvro all keep their signatures. A new ArrowSchemaFromAvroJSON is added alongside as the recommended entry point — it doesn't couple callers to a particular Avro library through its type signature. ArrowSchemaFromAvro is now marked Deprecated and serializes the hamba schema via json.Marshal before re-parsing with twmb; the github.com/hamba/avro/v2 dependency is kept so downstream callers continue to compile. Newly supported logical types (the mapping had them commented out under hamba): - local-timestamp-millis - local-timestamp-micros - local-timestamp-nanos - timestamp-nanos Fixes: - fixed(16)+uuid rows are no longer silently dropped — the UUIDBuilder now handles [16]byte directly. - ArrowSchemaFromAvro (deprecated wrapper) preserves logical-type annotations on fixed (uuid/decimal/duration). The wrapper now uses json.Marshal, which dispatches to each schema type's MarshalJSON; hambaAvro.Schema.String() returns Avro Parsing Canonical Form, which strips logical types by spec, and was the wrong serializer to bridge between libraries. - Heterogeneous and non-nullable multi-branch unions (e.g. 
["null", A, B], ["int", "string"]) used to silently mis-decode against the wrong branch or fall through to an opaque downstream nil-field error. They now fail upfront with "unsupported avro union at <path>". - Nullable unions written as ["T", "null"] (with null in the second position) now resolve to T as expected; the old code always took Types()[1] and produced a Null-typed column for that ordering. - Map-of-primitive schemas (e.g. {"type":"map","values":"int"}) now parse; the old code asserted the value type to NamedSchema and panicked on primitive values. - Unknown SchemaNode types used to leave a nameless nil-typed field that exploded in the record builder; now they fail with "unhandled avro type %q", surfaced as "invalid avro schema". - Duration's []byte arm read Uint16 with gaps and overflowed uint32 on the milli multiply — gone; only the avro.Duration arm remains. - append*Data helpers no longer silently no-op or fmt.Sprint-coerce unknown inputs; they return "unsupported value of type %T". - Enum-symbol mismatches in appendBinaryDictData surface a clear error instead of a generic dictionary builder failure. - Schema-conversion errors wrap with %w so callers can errors.Unwrap past arrow.ErrInvalid. OCF wire format is unchanged: twmb/avro/ocf supports the same codecs (null, deflate, snappy, zstd) and files written by hamba remain readable. 
--- arrow/avro/loader.go | 3 +- arrow/avro/reader.go | 27 +- arrow/avro/reader_test.go | 129 +-------- arrow/avro/reader_types.go | 444 +++++++++--------------------- arrow/avro/schema.go | 326 +++++++++++++--------- arrow/avro/schema_test.go | 47 ++-- arrow/avro/testdata/alltypes.avsc | 9 + arrow/avro/testdata/testdata.go | 259 +++++++++++------ go.mod | 2 +- go.sum | 4 +- 10 files changed, 566 insertions(+), 684 deletions(-) diff --git a/arrow/avro/loader.go b/arrow/avro/loader.go index a7199e661..fa97c426b 100644 --- a/arrow/avro/loader.go +++ b/arrow/avro/loader.go @@ -24,7 +24,7 @@ import ( func (r *OCFReader) decodeOCFToChan() { defer close(r.avroChan) - for r.r.HasNext() { + for { select { case <-r.readerCtx.Done(): r.err = fmt.Errorf("avro decoding cancelled, %d records read", r.avroDatumCount) @@ -34,7 +34,6 @@ func (r *OCFReader) decodeOCFToChan() { err := r.r.Decode(&datum) if err != nil { if errors.Is(err, io.EOF) { - r.err = nil return } r.err = err diff --git a/arrow/avro/reader.go b/arrow/avro/reader.go index db6de6275..a731a0621 100644 --- a/arrow/avro/reader.go +++ b/arrow/avro/reader.go @@ -27,10 +27,9 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/internal/debug" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/hamba/avro/v2/ocf" "github.com/tidwall/sjson" - - avro "github.com/hamba/avro/v2" + "github.com/twmb/avro" + "github.com/twmb/avro/ocf" ) var ErrMismatchFields = errors.New("arrow/avro: number of records mismatch") @@ -47,9 +46,9 @@ type schemaEdit struct { value any } -// Reader wraps goavro/OCFReader and creates array.RecordBatches from a schema. +// OCFReader reads Avro OCF files and exposes them as array.RecordBatches. 
type OCFReader struct { - r *ocf.Decoder + r *ocf.Reader avroSchema string avroSchemaEdits []schemaEdit schema *arrow.Schema @@ -82,7 +81,7 @@ type OCFReader struct { // NewReader returns a reader that reads from an Avro OCF file and creates // arrow.RecordBatches from the converted avro data. func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { - ocfr, err := ocf.NewDecoder(r) + ocfr, err := ocf.NewReader(r) if err != nil { return nil, fmt.Errorf("%w: could not create avro ocfreader", arrow.ErrInvalid) } @@ -108,22 +107,20 @@ func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { } rr.avroSchema = schema.String() if len(rr.avroSchemaEdits) > 0 { - // execute schema edits for _, e := range rr.avroSchemaEdits { err := rr.editAvroSchema(e) if err != nil { return nil, fmt.Errorf("%w: could not edit avro schema", arrow.ErrInvalid) } } - // validate edited schema - schema, err = avro.Parse(rr.avroSchema) - if err != nil { - return nil, fmt.Errorf("%w: could not parse modified avro schema", arrow.ErrInvalid) - } } - rr.schema, err = ArrowSchemaFromAvro(schema) + rr.schema, err = ArrowSchemaFromAvroJSON(rr.avroSchema) if err != nil { - return nil, fmt.Errorf("%w: could not convert avro schema", arrow.ErrInvalid) + msg := "could not convert avro schema" + if len(rr.avroSchemaEdits) > 0 { + msg = "could not parse modified avro schema" + } + return nil, fmt.Errorf("%w: %s: %w", arrow.ErrInvalid, msg, err) } if rr.mem == nil { rr.mem = memory.DefaultAllocator @@ -147,7 +144,7 @@ func NewOCFReader(r io.Reader, opts ...Option) (*OCFReader, error) { func (rr *OCFReader) Reuse(r io.Reader, opts ...Option) error { rr.Close() rr.err = nil - ocfr, err := ocf.NewDecoder(r) + ocfr, err := ocf.NewReader(r) if err != nil { return fmt.Errorf("%w: could not create avro ocfreader", arrow.ErrInvalid) } diff --git a/arrow/avro/reader_test.go b/arrow/avro/reader_test.go index 5c57e2d62..5ab18aba3 100644 --- a/arrow/avro/reader_test.go +++ 
b/arrow/avro/reader_test.go @@ -19,17 +19,13 @@ package avro import ( "bytes" "encoding/json" - "fmt" "os" "path/filepath" "testing" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/avro/testdata" - "github.com/apache/arrow-go/v18/arrow/memory" - hamba "github.com/hamba/avro/v2" - "github.com/hamba/avro/v2/ocf" + "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/stretchr/testify/assert" ) @@ -130,6 +126,10 @@ func TestReader(t *testing.T) { Name: "uuidField", Type: arrow.BinaryTypes.String, }, + { + Name: "fixedUuidField", + Type: extensions.NewUUIDType(), + }, { Name: "timemillis", Type: arrow.FixedWidthTypes.Time32ms, @@ -178,20 +178,13 @@ func TestReader(t *testing.T) { t.Fatal(err) } r := new(OCFReader) - r.avroSchema = schema.String() + r.avroSchema = schema r.editAvroSchema(schemaEdit{method: "delete", path: "fields.0"}) - schema, err = hamba.Parse(r.avroSchema) + got, err := ArrowSchemaFromAvroJSON(r.avroSchema) if err != nil { t.Fatalf("%v: could not parse modified avro schema", arrow.ErrInvalid) } - got, err := ArrowSchemaFromAvro(schema) - if err != nil { - t.Fatalf("%v", err) - } assert.Equal(t, want.String(), got.String()) - if fmt.Sprintf("%+v", want.String()) != fmt.Sprintf("%+v", got.String()) { - t.Fatalf("got=%v,\n want=%v", got.String(), want.String()) - } }) t.Run("ShouldLoadExpectedRecords", func(t *testing.T) { @@ -211,7 +204,7 @@ func TestReader(t *testing.T) { exists := ar.Next() if ar.Err() != nil { - t.Error("failed to read next record: %w", ar.Err()) + t.Errorf("failed to read next record: %v", ar.Err()) } if !exists { t.Error("no record exists") @@ -230,109 +223,3 @@ func TestReader(t *testing.T) { }) } } - -// TestOCFReaderBytesValues exercises avro `bytes` fields, both plain and as a -// ["null","bytes"] union: hamba hands the decoded value to the appenders as a -// bare []byte, which previously fell into appendBinaryData's fmt fallback and -// 
appended the formatted text (e.g. "[1 2 3]") instead of the payload. -func TestOCFReaderBytesValues(t *testing.T) { - schema := `{ - "type": "record", - "name": "rec", - "fields": [ - {"name": "plain", "type": "bytes"}, - {"name": "nullable", "type": ["null", "bytes"]} - ] - }` - payload := []byte{0x00, 0x01, 0xfe, 0xff} - - var buf bytes.Buffer - enc, err := ocf.NewEncoder(schema, &buf) - assert.NoError(t, err) - assert.NoError(t, enc.Encode(map[string]any{ - "plain": payload, - "nullable": map[string]any{"bytes": payload}, - })) - assert.NoError(t, enc.Encode(map[string]any{ - "plain": []byte{}, - "nullable": nil, - })) - assert.NoError(t, enc.Close()) - - ar, err := NewOCFReader(bytes.NewReader(buf.Bytes()), WithChunk(-1)) - assert.NoError(t, err) - defer ar.Close() - - assert.True(t, ar.Next()) - assert.NoError(t, ar.Err()) - rec := ar.RecordBatch() - - plain := rec.Column(0).(*array.Binary) - assert.Equal(t, payload, plain.Value(0)) - assert.Equal(t, []byte{}, plain.Value(1)) - - nullable := rec.Column(1).(*array.Binary) - assert.Equal(t, payload, nullable.Value(0)) - assert.True(t, nullable.IsNull(1)) -} - -// Types outside what the hamba decoder produces must error rather than append -// a fmt-formatted rendering of the value. 
-func TestAppendBinaryAndStringDataUnexpectedTypes(t *testing.T) { - bb := array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary) - defer bb.Release() - - assert.NoError(t, appendBinaryData(bb, []byte{0x01})) - assert.NoError(t, appendBinaryData(bb, nil)) - assert.NoError(t, appendBinaryData(bb, map[string]any{"bytes": []byte{0x02}})) - assert.ErrorContains(t, appendBinaryData(bb, 42), "unexpected type int") - assert.ErrorContains(t, appendBinaryData(bb, map[string]any{"bytes": "text"}), "unexpected type string") - assert.Equal(t, 3, bb.Len()) - - sb := array.NewStringBuilder(memory.DefaultAllocator) - defer sb.Release() - - assert.NoError(t, appendStringData(sb, "ok")) - assert.NoError(t, appendStringData(sb, []byte("ok"))) - assert.NoError(t, appendStringData(sb, nil)) - assert.NoError(t, appendStringData(sb, map[string]any{"string": "ok"})) - assert.ErrorContains(t, appendStringData(sb, 42), "unexpected type int") - assert.ErrorContains(t, appendStringData(sb, map[string]any{"string": 42}), "unexpected type int") - assert.Equal(t, 4, sb.Len()) -} - -// loadDatum must surface appender errors from nested paths (map values, -// list items), not only from top-level and struct fields. 
-func TestLoadDatumPropagatesNestedAppendErrors(t *testing.T) { - newLoader := func(t *testing.T, avroSchema string) (*dataLoader, *array.RecordBuilder) { - t.Helper() - schema, err := hamba.Parse(avroSchema) - assert.NoError(t, err) - arrowSchema, err := ArrowSchemaFromAvro(schema) - assert.NoError(t, err) - bld := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema) - pos := newFieldPos() - ldr := newDataLoader() - for idx, fb := range bld.Fields() { - mapFieldBuilders(fb, arrowSchema.Field(idx), pos) - } - ldr.drawTree(pos) - return ldr, bld - } - - t.Run("map value", func(t *testing.T) { - ldr, bld := newLoader(t, `{"type":"record","name":"r","fields":[ - {"name":"m","type":{"type":"map","values":"bytes"}}]}`) - defer bld.Release() - assert.NoError(t, ldr.loadDatum(map[string]any{"m": map[string]any{"k": []byte{0x01}}})) - assert.ErrorContains(t, ldr.loadDatum(map[string]any{"m": map[string]any{"k": 42}}), "unexpected type int") - }) - - t.Run("list item", func(t *testing.T) { - ldr, bld := newLoader(t, `{"type":"record","name":"r","fields":[ - {"name":"l","type":{"type":"array","items":"bytes"}}]}`) - defer bld.Release() - assert.NoError(t, ldr.loadDatum(map[string]any{"l": []any{[]byte{0x01}}})) - assert.ErrorContains(t, ldr.loadDatum(map[string]any{"l": []any{42}}), "unexpected type int") - }) -} diff --git a/arrow/avro/reader_types.go b/arrow/avro/reader_types.go index 45a7b145d..da13b03d6 100644 --- a/arrow/avro/reader_types.go +++ b/arrow/avro/reader_types.go @@ -17,8 +17,6 @@ package avro import ( - "bytes" - "encoding/binary" "errors" "fmt" "math/big" @@ -31,7 +29,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/decimal256" "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/arrow/memory" - hamba "github.com/hamba/avro/v2" + avro "github.com/twmb/avro" ) type dataLoader struct { @@ -92,21 +90,10 @@ func (d *dataLoader) drawTree(field *fieldPos) { // Since array.StructBuilder.AppendNull() will recursively append 
null to all of the // struct's fields, in the case of nil being passed to a struct's builderFunc it will // return a ErrNullStructData error to signal that all its sub-fields can be skipped. -// filterNullStruct drops ErrNullStructData, which signals a null struct -// whose sub-fields can be skipped rather than a failure. -func filterNullStruct(err error) error { - if err == ErrNullStructData { - return nil - } - return err -} - func (d *dataLoader) loadDatum(data any) error { if d.list == nil && d.mapField == nil { if d.mapValue != nil { - if err := filterNullStruct(d.mapValue.appendFunc(data)); err != nil { - return err - } + d.mapValue.appendFunc(data) } var NullParent *fieldPos for _, f := range d.fields { @@ -147,9 +134,7 @@ func (d *dataLoader) loadDatum(data any) error { } } else { for _, e := range dt { - if err := d.children[0].loadDatum(e); err != nil { - return err - } + d.children[0].loadDatum(e) } } case map[string]any: @@ -167,24 +152,16 @@ func (d *dataLoader) loadDatum(data any) error { } for _, c := range d.children { if c.list != nil { - if err := c.loadDatum(c.list.getValue(data)); err != nil { - return err - } + c.loadDatum(c.list.getValue(data)) } if c.mapField != nil { switch dt := data.(type) { case nil: - if err := c.loadDatum(dt); err != nil { - return err - } + c.loadDatum(dt) case map[string]any: - if err := c.loadDatum(c.mapField.getValue(dt)); err != nil { - return err - } + c.loadDatum(c.mapField.getValue(dt)) default: - if err := c.loadDatum(c.mapField.getValue(data)); err != nil { - return err - } + c.loadDatum(c.mapField.getValue(data)) } } } @@ -192,18 +169,12 @@ func (d *dataLoader) loadDatum(data any) error { if d.list != nil { switch dt := data.(type) { case nil: - if err := filterNullStruct(d.list.appendFunc(dt)); err != nil { - return err - } + d.list.appendFunc(dt) case []any: - if err := filterNullStruct(d.list.appendFunc(dt)); err != nil { - return err - } + d.list.appendFunc(dt) for _, e := range dt { if d.item != nil { - 
if err := filterNullStruct(d.item.appendFunc(e)); err != nil { - return err - } + d.item.appendFunc(e) } var NullParent *fieldPos for _, f := range d.fields { @@ -221,26 +192,18 @@ func (d *dataLoader) loadDatum(data any) error { } for _, c := range d.children { if c.list != nil { - if err := c.loadDatum(c.list.getValue(e)); err != nil { - return err - } + c.loadDatum(c.list.getValue(e)) } if c.mapField != nil { - if err := c.loadDatum(c.mapField.getValue(e)); err != nil { - return err - } + c.loadDatum(c.mapField.getValue(e)) } } } case map[string]any: - if err := filterNullStruct(d.list.appendFunc(dt["array"])); err != nil { - return err - } + d.list.appendFunc(dt["array"]) for _, e := range dt["array"].([]any) { if d.item != nil { - if err := filterNullStruct(d.item.appendFunc(e)); err != nil { - return err - } + d.item.appendFunc(e) } var NullParent *fieldPos for _, f := range d.fields { @@ -257,40 +220,27 @@ func (d *dataLoader) loadDatum(data any) error { } } for _, c := range d.children { - if err := c.loadDatum(c.list.getValue(e)); err != nil { - return err - } + c.loadDatum(c.list.getValue(e)) } } default: - if err := filterNullStruct(d.list.appendFunc(data)); err != nil { - return err - } - if err := filterNullStruct(d.item.appendFunc(dt)); err != nil { - return err - } + d.list.appendFunc(data) + d.item.appendFunc(dt) } } if d.mapField != nil { switch dt := data.(type) { case nil: - if err := filterNullStruct(d.mapField.appendFunc(dt)); err != nil { - return err - } + d.mapField.appendFunc(dt) case map[string]any: - if err := filterNullStruct(d.mapField.appendFunc(dt)); err != nil { - return err - } + + d.mapField.appendFunc(dt) for k, v := range dt { - if err := filterNullStruct(d.mapKey.appendFunc(k)); err != nil { - return err - } + d.mapKey.appendFunc(k) if d.mapValue != nil { - if err := filterNullStruct(d.mapValue.appendFunc(v)); err != nil { - return err - } - } else if err := d.children[0].loadDatum(v); err != nil { - return err + 
d.mapValue.appendFunc(v) + } else { + d.children[0].loadDatum(v) } } } @@ -450,8 +400,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { case *array.BinaryDictionaryBuilder: // has metadata for Avro enum symbols f.appendFunc = func(data interface{}) error { - appendBinaryDictData(bt, data) - return nil + return appendBinaryDictData(bt, data) } // add Avro enum symbols to builder sb := array.NewStringBuilder(memory.DefaultAllocator) @@ -462,13 +411,11 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { bt.InsertStringDictValues(sa) case *array.BooleanBuilder: f.appendFunc = func(data interface{}) error { - appendBoolData(bt, data) - return nil + return appendBoolData(bt, data) } case *array.Date32Builder: f.appendFunc = func(data interface{}) error { - appendDate32Data(bt, data) - return nil + return appendDate32Data(bt, data) } case *array.Decimal128Builder: f.appendFunc = func(data interface{}) error { @@ -476,11 +423,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { if !ok { return nil } - err := appendDecimal128Data(bt, data, typ) - if err != nil { - return err - } - return nil + return appendDecimal128Data(bt, data, typ) } case *array.Decimal256Builder: f.appendFunc = func(data interface{}) error { @@ -488,54 +431,31 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { if !ok { return nil } - err := appendDecimal256Data(bt, data, typ) - if err != nil { - return err - } - return nil + return appendDecimal256Data(bt, data, typ) } case *extensions.UUIDBuilder: f.appendFunc = func(data interface{}) error { - switch dt := data.(type) { - case nil: - bt.AppendNull() - case string: - err := bt.AppendValueFromString(dt) - if err != nil { - return err - } - case []byte: - err := bt.AppendValueFromString(string(dt)) - if err != nil { - return err - } - } - return nil + return appendUUIDData(bt, data, field.Name) } case *array.FixedSizeBinaryBuilder: 
f.appendFunc = func(data interface{}) error { - appendFixedSizeBinaryData(bt, data) - return nil + return appendFixedSizeBinaryData(bt, data) } case *array.Float32Builder: f.appendFunc = func(data interface{}) error { - appendFloat32Data(bt, data) - return nil + return appendFloat32Data(bt, data) } case *array.Float64Builder: f.appendFunc = func(data interface{}) error { - appendFloat64Data(bt, data) - return nil + return appendFloat64Data(bt, data) } case *array.Int32Builder: f.appendFunc = func(data interface{}) error { - appendInt32Data(bt, data) - return nil + return appendInt32Data(bt, data) } case *array.Int64Builder: f.appendFunc = func(data interface{}) error { - appendInt64Data(bt, data) - return nil + return appendInt64Data(bt, data) } case *array.LargeListBuilder: vb := bt.ValueBuilder() @@ -593,8 +513,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } case *array.MonthDayNanoIntervalBuilder: f.appendFunc = func(data interface{}) error { - appendDurationData(bt, data) - return nil + return appendDurationData(bt, data) } case *array.StringBuilder: f.appendFunc = func(data interface{}) error { @@ -620,131 +539,108 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) { } case *array.Time32Builder: f.appendFunc = func(data interface{}) error { - appendTime32Data(bt, data) - return nil + return appendTime32Data(bt, data) } case *array.Time64Builder: f.appendFunc = func(data interface{}) error { - appendTime64Data(bt, data) - return nil + return appendTime64Data(bt, data) } case *array.TimestampBuilder: f.appendFunc = func(data interface{}) error { - appendTimestampData(bt, data) - return nil + return appendTimestampData(bt, data) } } } -func appendBinaryData(b *array.BinaryBuilder, data interface{}) error { +// appendUUIDData accepts the two shapes a UUID may arrive as: a [16]byte +// (fixed(16)+uuid) or a hex-dash string (string+uuid). Other byte lengths +// are rejected rather than re-interpreted. 
+func appendUUIDData(b *extensions.UUIDBuilder, data any, fieldName string) error { switch dt := data.(type) { case nil: b.AppendNull() + case string: + return b.AppendValueFromString(dt) + case [16]byte: + b.AppendBytes(dt) case []byte: - b.Append(dt) - case map[string]any: - switch ct := dt["bytes"].(type) { - case nil: - b.AppendNull() - case []byte: - b.Append(ct) + switch len(dt) { + case 16: + b.AppendBytes([16]byte(dt)) + case 36: + return b.AppendValueFromString(string(dt)) default: - return fmt.Errorf("unexpected type %T for avro bytes union value", ct) + return fmt.Errorf("avro: %d-byte value cannot be a UUID for column %q", len(dt), fieldName) } default: - return fmt.Errorf("unexpected type %T for avro bytes value", data) + return fmt.Errorf("avro: unsupported value of type %T for UUID column %q", data, fieldName) + } + return nil +} + +func appendBinaryData(b *array.BinaryBuilder, data interface{}) error { + switch dt := data.(type) { + case nil: + b.AppendNull() + case []byte: + b.Append(dt) + default: + return fmt.Errorf("avro: unsupported value of type %T for Binary column", data) } return nil } -func appendBinaryDictData(b *array.BinaryDictionaryBuilder, data interface{}) { +func appendBinaryDictData(b *array.BinaryDictionaryBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case string: - b.AppendString(dt) - case map[string]any: - switch v := dt["string"].(type) { - case nil: - b.AppendNull() - case string: - b.AppendString(v) + if err := b.AppendString(dt); err != nil { + return fmt.Errorf("avro: enum symbol %q is not in the dictionary (schema/data mismatch?): %w", dt, err) } + default: + return fmt.Errorf("avro: unsupported value of type %T for Dictionary column", data) } + return nil } -func appendBoolData(b *array.BooleanBuilder, data interface{}) { +func appendBoolData(b *array.BooleanBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case bool: b.Append(dt) - case 
map[string]any: - switch v := dt["boolean"].(type) { - case nil: - b.AppendNull() - case bool: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Boolean column", data) } + return nil } -func appendDate32Data(b *array.Date32Builder, data interface{}) { +func appendDate32Data(b *array.Date32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int32: - b.Append(arrow.Date32(dt)) - case map[string]any: - switch v := dt["int"].(type) { - case nil: - b.AppendNull() - case int32: - b.Append(arrow.Date32(v)) - } case time.Time: b.Append(arrow.Date32FromTime(dt)) + default: + return fmt.Errorf("avro: unsupported value of type %T for Date32 column", data) } + return nil } func appendDecimal128Data(b *array.Decimal128Builder, data interface{}, typ arrow.DecimalType) error { switch dt := data.(type) { case nil: b.AppendNull() - case []byte: - buf := bytes.NewBuffer(dt) - if len(dt) <= 38 { - var intData int64 - err := binary.Read(buf, binary.BigEndian, &intData) - if err != nil { - return err - } - b.Append(decimal128.FromI64(intData)) - } else { - var bigIntData big.Int - b.Append(decimal128.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) - } - case map[string]any: - buf := bytes.NewBuffer(dt["bytes"].([]byte)) - if len(dt["bytes"].([]byte)) <= 38 { - var intData int64 - err := binary.Read(buf, binary.BigEndian, &intData) - if err != nil { - return err - } - b.Append(decimal128.FromI64(intData)) - } else { - var bigIntData big.Int - b.Append(decimal128.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) - } case *big.Rat: v := bigRatToBigInt(dt, typ) - if v.IsInt64() { b.Append(decimal128.FromI64(v.Int64())) } else { b.Append(decimal128.FromBigInt(v)) } + default: + return fmt.Errorf("avro: unsupported value of type %T for Decimal128 column", data) } return nil } @@ -753,16 +649,10 @@ func appendDecimal256Data(b *array.Decimal256Builder, data interface{}, typ arro switch dt := data.(type) { case nil: 
b.AppendNull() - case []byte: - var bigIntData big.Int - buf := bytes.NewBuffer(dt) - b.Append(decimal256.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) - case map[string]any: - var bigIntData big.Int - buf := bytes.NewBuffer(dt["bytes"].([]byte)) - b.Append(decimal256.FromBigInt(bigIntData.SetBytes(buf.Bytes()))) case *big.Rat: b.Append(decimal256.FromBigInt(bigRatToBigInt(dt, typ))) + default: + return fmt.Errorf("avro: unsupported value of type %T for Decimal256 column", data) } return nil } @@ -779,129 +669,88 @@ func bigRatToBigInt(dt *big.Rat, typ arrow.DecimalType) *big.Int { // Avro duration logical type annotates Avro fixed type of size 12, which stores three little-endian // unsigned integers that represent durations at different granularities of time. The first stores // a number in months, the second stores a number in days, and the third stores a number in milliseconds. -func appendDurationData(b *array.MonthDayNanoIntervalBuilder, data interface{}) { +func appendDurationData(b *array.MonthDayNanoIntervalBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case []byte: - dur := new(arrow.MonthDayNanoInterval) - dur.Months = int32(binary.LittleEndian.Uint16(dt[:3])) - dur.Days = int32(binary.LittleEndian.Uint16(dt[4:7])) - dur.Nanoseconds = int64(binary.LittleEndian.Uint32(dt[8:]) * 1000000) - b.Append(*dur) - case map[string]any: - switch dtb := dt["bytes"].(type) { - case nil: - b.AppendNull() - case []byte: - dur := new(arrow.MonthDayNanoInterval) - dur.Months = int32(binary.LittleEndian.Uint16(dtb[:3])) - dur.Days = int32(binary.LittleEndian.Uint16(dtb[4:7])) - dur.Nanoseconds = int64(binary.LittleEndian.Uint32(dtb[8:]) * 1000000) - b.Append(*dur) - } - case hamba.LogicalDuration: + case avro.Duration: b.Append(arrow.MonthDayNanoInterval{ Months: int32(dt.Months), Days: int32(dt.Days), Nanoseconds: int64(dt.Milliseconds) * int64(time.Millisecond), }) + default: + return fmt.Errorf("avro: unsupported value of type 
%T for Duration column", data) } + return nil } -func appendFixedSizeBinaryData(b *array.FixedSizeBinaryBuilder, data interface{}) { +func appendFixedSizeBinaryData(b *array.FixedSizeBinaryBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case []byte: b.Append(dt) - case map[string]any: - switch v := dt["bytes"].(type) { - case nil: - b.AppendNull() - case []byte: - b.Append(v) - } default: + // fixed(N) may arrive as a Go [N]byte; accept any byte-array via reflection. v := reflect.ValueOf(data) if v.Kind() == reflect.Array && v.Type().Elem().Kind() == reflect.Uint8 { - bytes := make([]byte, v.Len()) - reflect.Copy(reflect.ValueOf(bytes), v) - b.Append(bytes) + buf := make([]byte, v.Len()) + reflect.Copy(reflect.ValueOf(buf), v) + b.Append(buf) + return nil } + return fmt.Errorf("avro: unsupported value of type %T for FixedSizeBinary column", data) } + return nil } -func appendFloat32Data(b *array.Float32Builder, data interface{}) { +func appendFloat32Data(b *array.Float32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case float32: b.Append(dt) - case map[string]any: - switch v := dt["float"].(type) { - case nil: - b.AppendNull() - case float32: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Float32 column", data) } + return nil } -func appendFloat64Data(b *array.Float64Builder, data interface{}) { +func appendFloat64Data(b *array.Float64Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() case float64: b.Append(dt) - case map[string]any: - switch v := dt["double"].(type) { - case nil: - b.AppendNull() - case float64: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Float64 column", data) } + return nil } -func appendInt32Data(b *array.Int32Builder, data interface{}) { +func appendInt32Data(b *array.Int32Builder, data interface{}) error { switch dt := data.(type) { case nil: 
b.AppendNull() - case int: - b.Append(int32(dt)) case int32: b.Append(dt) - case map[string]any: - switch v := dt["int"].(type) { - case nil: - b.AppendNull() - case int: - b.Append(int32(v)) - case int32: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Int32 column", data) } + return nil } -func appendInt64Data(b *array.Int64Builder, data interface{}) { +func appendInt64Data(b *array.Int64Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int: - b.Append(int64(dt)) case int64: b.Append(dt) - case map[string]any: - switch v := dt["long"].(type) { - case nil: - b.AppendNull() - case int: - b.Append(int64(v)) - case int64: - b.Append(v) - } + default: + return fmt.Errorf("avro: unsupported value of type %T for Int64 column", data) } + return nil } func appendStringData(b *array.StringBuilder, data interface{}) error { @@ -910,85 +759,48 @@ func appendStringData(b *array.StringBuilder, data interface{}) error { b.AppendNull() case string: b.Append(dt) - case []byte: - b.Append(string(dt)) - case map[string]any: - switch v := dt["string"].(type) { - case nil: - b.AppendNull() - case string: - b.Append(v) - default: - return fmt.Errorf("unexpected type %T for avro string union value", v) - } default: - return fmt.Errorf("unexpected type %T for avro string value", data) + return fmt.Errorf("avro: unsupported value of type %T for String column", data) } return nil } -func appendTime32Data(b *array.Time32Builder, data interface{}) { +func appendTime32Data(b *array.Time32Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int32: - b.Append(arrow.Time32(dt)) - case map[string]any: - switch v := dt["int"].(type) { - case nil: - b.AppendNull() - case int32: - b.Append(arrow.Time32(v)) - } case time.Duration: b.Append(arrow.Time32(dt.Milliseconds())) + default: + return fmt.Errorf("avro: unsupported value of type %T for Time32 column", data) } + return 
nil } -func appendTime64Data(b *array.Time64Builder, data interface{}) { +func appendTime64Data(b *array.Time64Builder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int64: - b.Append(arrow.Time64(dt)) - case map[string]any: - switch v := dt["long"].(type) { - case nil: - b.AppendNull() - case int64: - b.Append(arrow.Time64(v)) - } case time.Duration: b.Append(arrow.Time64(dt.Microseconds())) + default: + return fmt.Errorf("avro: unsupported value of type %T for Time64 column", data) } + return nil } -func appendTimestampData(b *array.TimestampBuilder, data interface{}) { +func appendTimestampData(b *array.TimestampBuilder, data interface{}) error { switch dt := data.(type) { case nil: b.AppendNull() - case int64: - b.Append(arrow.Timestamp(dt)) - case map[string]any: - switch v := dt["long"].(type) { - case nil: - b.AppendNull() - case int64: - b.Append(arrow.Timestamp(v)) - } case time.Time: - tt := b.Type().(*arrow.TimestampType) - // hamba decodes a local-timestamp logical type into a time.Time whose wall-clock - // fields hold the intended value but whose instant is offset by the decoder's local - // zone. Arrow stores local (zone-less) timestamps as the wall clock read in UTC, so - // reinterpret the fields in UTC to keep the value zone-independent. 
- if tt.TimeZone == "" { - dt = time.Date(dt.Year(), dt.Month(), dt.Day(), dt.Hour(), dt.Minute(), dt.Second(), dt.Nanosecond(), time.UTC) - } - v, err := arrow.TimestampFromTime(dt, tt.Unit) + v, err := arrow.TimestampFromTime(dt, b.Type().(*arrow.TimestampType).Unit) if err != nil { - panic(err) + return err } b.Append(v) + default: + return fmt.Errorf("avro: unsupported value of type %T for Timestamp column", data) } + return nil } diff --git a/arrow/avro/schema.go b/arrow/avro/schema.go index 6523096c8..13214ca23 100644 --- a/arrow/avro/schema.go +++ b/arrow/avro/schema.go @@ -18,6 +18,7 @@ package avro import ( + "encoding/json" "fmt" "math" "strconv" @@ -26,24 +27,34 @@ import ( "github.com/apache/arrow-go/v18/arrow/decimal128" "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/arrow-go/v18/internal/utils" - avro "github.com/hamba/avro/v2" + hambaAvro "github.com/hamba/avro/v2" + avro "github.com/twmb/avro" ) +// builtinAvroTypes is the set of Type field values that mean "this SchemaNode +// is the inline definition of an Avro type." Anything else in node.Type is +// treated as a named-type reference to a previously-seen record/enum/fixed. 
+var builtinAvroTypes = map[string]struct{}{ + "null": {}, "boolean": {}, "int": {}, "long": {}, + "float": {}, "double": {}, "bytes": {}, "string": {}, + "record": {}, "enum": {}, "array": {}, "map": {}, + "fixed": {}, "union": {}, +} + type schemaNode struct { - name string - parent *schemaNode - schema avro.Schema - union bool - nullable bool - childrens []*schemaNode - arrowField arrow.Field - schemaCache *avro.SchemaCache - index, depth int32 + name string + parent *schemaNode + node avro.SchemaNode + union bool + nullable bool + childrens []*schemaNode + arrowField arrow.Field + namedCache map[string]avro.SchemaNode + index int32 } func newSchemaNode() *schemaNode { - var schemaCache avro.SchemaCache - return &schemaNode{name: "", index: -1, schemaCache: &schemaCache} + return &schemaNode{index: -1, namedCache: map[string]avro.SchemaNode{}} } func (node *schemaNode) schemaPath() string { @@ -56,33 +67,84 @@ func (node *schemaNode) schemaPath() string { return path } -func (node *schemaNode) newChild(n string, s avro.Schema) *schemaNode { +func (node *schemaNode) newChild(n string, s avro.SchemaNode) *schemaNode { child := &schemaNode{ - name: n, - parent: node, - schema: s, - schemaCache: node.schemaCache, - index: int32(len(node.childrens)), - depth: node.depth + 1, + name: n, + parent: node, + node: s, + namedCache: node.namedCache, + index: int32(len(node.childrens)), } node.childrens = append(node.childrens, child) return child } func (node *schemaNode) children() []*schemaNode { return node.childrens } -// func (node *schemaNode) nodeName() string { return node.name } +// rememberNamed adds a record/enum/fixed SchemaNode to the named-type cache +// under both its short name and (if a namespace is present) its full name, +// so later references like {"type": "Address"} or {"type": "ns.Address"} +// resolve back to the original definition. 
+func (node *schemaNode) rememberNamed(s avro.SchemaNode) { + if s.Name == "" { + return + } + node.namedCache[s.Name] = s + if s.Namespace != "" { + node.namedCache[s.Namespace+"."+s.Name] = s + } +} + +// resolveRef replaces s with its inline definition if s.Type is a named-type +// reference rather than a builtin Avro type. atField, when non-empty, names +// the field this reference appears in and is included in the panic so the +// user can locate the offending entry. +func (node *schemaNode) resolveRef(s avro.SchemaNode, atField string) avro.SchemaNode { + if _, ok := builtinAvroTypes[s.Type]; ok { + return s + } + if def, ok := node.namedCache[s.Type]; ok { + return def + } + loc := node.schemaPath() + if atField != "" { + loc += "." + atField + } + panic(fmt.Errorf("unknown named type %q referenced at %s", s.Type, loc)) +} + +// ArrowSchemaFromAvroJSON parses an Avro schema given as JSON text and returns +// the equivalent Arrow schema. +func ArrowSchemaFromAvroJSON(schemaJSON string) (*arrow.Schema, error) { + schema, err := avro.Parse(schemaJSON) + if err != nil { + return nil, err + } + return arrowSchemaFromAvroInternal(schema) +} + +// ArrowSchemaFromAvro returns a new Arrow schema from a parsed Avro schema. +// +// Deprecated: Use [ArrowSchemaFromAvroJSON] instead — it does not couple +// callers to a particular Avro library through its signature. 
+func ArrowSchemaFromAvro(schema hambaAvro.Schema) (*arrow.Schema, error) { + js, err := json.Marshal(schema) + if err != nil { + return nil, fmt.Errorf("%w: could not serialize hamba avro schema: %w", arrow.ErrInvalid, err) + } + return ArrowSchemaFromAvroJSON(string(js)) +} -// ArrowSchemaFromAvro returns a new Arrow schema from an Avro schema -func ArrowSchemaFromAvro(schema avro.Schema) (s *arrow.Schema, err error) { +func arrowSchemaFromAvroInternal(schema *avro.Schema) (s *arrow.Schema, err error) { defer func() { if r := recover(); r != nil { s = nil err = utils.FormatRecoveredError("invalid avro schema", r) } }() + root := schema.Root() n := newSchemaNode() - n.schema = schema - c := n.newChild(n.schema.(avro.NamedSchema).Name(), n.schema) + n.node = root + c := n.newChild(root.Name, root) arrowSchemafromAvro(c) var fields []arrow.Field for _, g := range c.children() { @@ -93,16 +155,16 @@ func ArrowSchemaFromAvro(schema avro.Schema) (s *arrow.Schema, err error) { } func arrowSchemafromAvro(n *schemaNode) { - if ns, ok := n.schema.(avro.NamedSchema); ok { - n.schemaCache.Add(ns.Name(), ns) + n.node = n.resolveRef(n.node, "") + if n.node.Name != "" { + n.rememberNamed(n.node) } - switch st := n.schema.Type(); st { + switch st := n.node.Type; st { case "record": iterateFields(n) case "enum": - n.schemaCache.Add(n.schema.(avro.NamedSchema).Name(), n.schema.(*avro.EnumSchema)) symbols := make(map[string]string) - for index, symbol := range n.schema.(avro.PropertySchema).(*avro.EnumSchema).Symbols() { + for index, symbol := range n.node.Symbols { k := strconv.FormatInt(int64(index), 10) symbols[k] = symbol } @@ -118,9 +180,12 @@ func arrowSchemafromAvro(n *schemaNode) { } n.arrowField = buildArrowField(n, &dt, arrow.MetadataFrom(symbols)) case "array": - // logical items type - c := n.newChild(n.name, n.schema.(*avro.ArraySchema).Items()) - if isLogicalSchemaType(n.schema.(*avro.ArraySchema).Items()) { + if n.node.Items == nil { + panic(fmt.Errorf("avro array 
schema at %s has no 'items'", n.schemaPath())) + } + items := *n.node.Items + c := n.newChild(n.name, items) + if isLogicalSchemaType(items) { avroLogicalToArrowField(c) } else { arrowSchemafromAvro(c) @@ -134,62 +199,58 @@ func arrowSchemafromAvro(n *schemaNode) { } n.arrowField = buildArrowField(n, typ, c.arrowField.Metadata) case "map": - n.schemaCache.Add(n.schema.(*avro.MapSchema).Values().(avro.NamedSchema).Name(), n.schema.(*avro.MapSchema).Values()) - c := n.newChild(n.name, n.schema.(*avro.MapSchema).Values()) + if n.node.Values == nil { + panic(fmt.Errorf("avro map schema at %s has no 'values'", n.schemaPath())) + } + values := *n.node.Values + c := n.newChild(n.name, values) arrowSchemafromAvro(c) n.arrowField = buildArrowField(n, arrow.MapOf(arrow.BinaryTypes.String, c.arrowField.Type), c.arrowField.Metadata) case "union": - us := n.schema.(*avro.UnionSchema) - if us.Nullable() { - if len(us.Types()) > 1 { - n.schema = us.Types()[1] - n.union = true - n.nullable = true - arrowSchemafromAvro(n) - } - } else { - panic(fmt.Errorf("complex (non-nullable) avro union at '%v' is not supported", n.schemaPath())) + branch, ok := nullableBranch(n.node) + if !ok { + panic(fmt.Errorf("unsupported avro union at %s: only ['null', T] unions with exactly one non-null branch are supported", n.schemaPath())) } + n.node = branch + n.union = true + n.nullable = true + arrowSchemafromAvro(n) // Avro "fixed" field type = Arrow FixedSize Primitive BinaryType case "fixed": - n.schemaCache.Add(n.schema.(avro.NamedSchema).Name(), n.schema.(*avro.FixedSchema)) - if isLogicalSchemaType(n.schema) { + if isLogicalSchemaType(n.node) { avroLogicalToArrowField(n) } else { - n.arrowField = buildArrowField(n, &arrow.FixedSizeBinaryType{ByteWidth: n.schema.(*avro.FixedSchema).Size()}, arrow.Metadata{}) + n.arrowField = buildArrowField(n, &arrow.FixedSizeBinaryType{ByteWidth: n.node.Size}, arrow.Metadata{}) } case "string", "bytes", "int", "long": - if isLogicalSchemaType(n.schema) { + if 
isLogicalSchemaType(n.node) { avroLogicalToArrowField(n) } else { n.arrowField = buildArrowField(n, avroPrimitiveToArrowType(string(st)), arrow.Metadata{}) } case "float", "double", "boolean": n.arrowField = buildArrowField(n, avroPrimitiveToArrowType(string(st)), arrow.Metadata{}) - case "": - refSchema := n.schemaCache.Get(string(n.schema.(*avro.RefSchema).Schema().Name())) - if refSchema == nil { - panic(fmt.Errorf("could not find schema for '%v' in schema cache - %v", n.schemaPath(), n.schema.(*avro.RefSchema).Schema().Name())) - } - n.schema = refSchema - arrowSchemafromAvro(n) case "null": - n.schemaCache.Add(n.schema.(*avro.MapSchema).Values().(avro.NamedSchema).Name(), &avro.NullSchema{}) n.nullable = true n.arrowField = buildArrowField(n, arrow.Null, arrow.Metadata{}) + default: + panic(fmt.Errorf("unhandled avro type %q at %s", st, n.schemaPath())) } } -// iterate record Fields() +// iterate record Fields func iterateFields(n *schemaNode) { - for _, f := range n.schema.(*avro.RecordSchema).Fields() { - switch ft := f.Type().(type) { + for _, f := range n.node.Fields { + ft := n.resolveRef(f.Type, f.Name) + switch ft.Type { // Avro "array" field type - case *avro.ArraySchema: - n.schemaCache.Add(f.Name(), ft.Items()) - // logical items type - c := n.newChild(f.Name(), ft.Items()) - if isLogicalSchemaType(ft.Items()) { + case "array": + if ft.Items == nil { + panic(fmt.Errorf("avro array field %s.%s has no 'items'", n.schemaPath(), f.Name)) + } + items := *ft.Items + c := n.newChild(f.Name, items) + if isLogicalSchemaType(items) { avroLogicalToArrowField(c) } else { arrowSchemafromAvro(c) @@ -201,11 +262,11 @@ func iterateFields(n *schemaNode) { c.arrowField = arrow.Field{Name: c.name, Type: arrow.ListOfNonNullable(c.arrowField.Type), Metadata: c.arrowField.Metadata} } // Avro "enum" field type = Arrow dictionary type - case *avro.EnumSchema: - n.schemaCache.Add(f.Type().(*avro.EnumSchema).Name(), f.Type()) - c := n.newChild(f.Name(), f.Type()) + case 
"enum": + n.rememberNamed(ft) + c := n.newChild(f.Name, ft) symbols := make(map[string]string) - for index, symbol := range ft.Symbols() { + for index, symbol := range ft.Symbols { k := strconv.FormatInt(int64(index), 10) symbols[k] = symbol } @@ -221,46 +282,43 @@ func iterateFields(n *schemaNode) { } c.arrowField = buildArrowField(c, &dt, arrow.MetadataFrom(symbols)) // Avro "fixed" field type = Arrow FixedSize Primitive BinaryType - case *avro.FixedSchema: - n.schemaCache.Add(f.Name(), f.Type()) - c := n.newChild(f.Name(), f.Type()) - if isLogicalSchemaType(f.Type()) { + case "fixed": + n.rememberNamed(ft) + c := n.newChild(f.Name, ft) + if isLogicalSchemaType(ft) { avroLogicalToArrowField(c) } else { arrowSchemafromAvro(c) } - case *avro.RecordSchema: - n.schemaCache.Add(f.Name(), f.Type()) - c := n.newChild(f.Name(), f.Type()) + case "record": + n.rememberNamed(ft) + c := n.newChild(f.Name, ft) iterateFields(c) - // Avro "map" field type - KVP with value of one type - keys are strings - case *avro.MapSchema: - n.schemaCache.Add(f.Name(), ft.Values()) - c := n.newChild(f.Name(), ft.Values()) + // Avro "map" field type - KVP with value of one type - keys are strings + case "map": + if ft.Values == nil { + panic(fmt.Errorf("avro map field %s.%s has no 'values'", n.schemaPath(), f.Name)) + } + values := *ft.Values + c := n.newChild(f.Name, values) arrowSchemafromAvro(c) c.arrowField = buildArrowField(c, arrow.MapOf(arrow.BinaryTypes.String, c.arrowField.Type), c.arrowField.Metadata) - case *avro.UnionSchema: - if ft.Nullable() { - if len(ft.Types()) > 1 { - n.schemaCache.Add(f.Name(), ft.Types()[1]) - c := n.newChild(f.Name(), ft.Types()[1]) - c.union = true - c.nullable = true - arrowSchemafromAvro(c) - } - } else { - panic(fmt.Errorf("complex (non-nullable) avro union in field '%v' is not supported", f.Name())) + case "union": + branch, ok := nullableBranch(ft) + if !ok { + panic(fmt.Errorf("unsupported avro union at %s.%s: only ['null', T] unions with exactly 
one non-null branch are supported", n.schemaPath(), f.Name)) } + c := n.newChild(f.Name, branch) + c.union = true + c.nullable = true + arrowSchemafromAvro(c) default: - n.schemaCache.Add(f.Name(), f.Type()) - if isLogicalSchemaType(f.Type()) { - c := n.newChild(f.Name(), f.Type()) + c := n.newChild(f.Name, ft) + if isLogicalSchemaType(ft) { avroLogicalToArrowField(c) } else { - c := n.newChild(f.Name(), f.Type()) arrowSchemafromAvro(c) } - } } var fields []arrow.Field @@ -268,7 +326,7 @@ func iterateFields(n *schemaNode) { fields = append(fields, child.arrowField) } - namedSchema, ok := isNamedSchema(n.schema) + namedSchema, ok := isNamedSchema(n.node) var md arrow.Metadata if ok && namedSchema != n.name+"_data" && n.union { @@ -277,22 +335,46 @@ func iterateFields(n *schemaNode) { n.arrowField = buildArrowField(n, arrow.StructOf(fields...), md) } -func isLogicalSchemaType(s avro.Schema) bool { - lts, ok := s.(avro.LogicalTypeSchema) - if !ok { - return false +// nullableBranch returns the single non-null branch of a nullable union +// (["null", T] or [T, "null"]), plus true if the union has that shape. If the +// union has more than one non-null branch or no null branch, ok is false. +// +// Multi-branch and non-nullable unions (e.g. ["null", "int", "string"] or +// ["int", "string"]) are not supported and callers panic on them rather +// than silently picking one arm. 
+func nullableBranch(s avro.SchemaNode) (avro.SchemaNode, bool) { + if s.Type != "union" || len(s.Branches) < 2 { + return avro.SchemaNode{}, false } - if lts.Logical() != nil { - return true + var nonNull *avro.SchemaNode + for i := range s.Branches { + b := s.Branches[i] + if b.Type == "null" { + continue + } + if nonNull != nil { + return avro.SchemaNode{}, false + } + nonNull = &b } - return false + if nonNull == nil { + return avro.SchemaNode{}, false + } + return *nonNull, true } -func isNamedSchema(s avro.Schema) (string, bool) { - if ns, ok := s.(avro.NamedSchema); ok { - return ns.FullName(), ok +func isLogicalSchemaType(s avro.SchemaNode) bool { + return s.LogicalType != "" +} + +func isNamedSchema(s avro.SchemaNode) (string, bool) { + if s.Name == "" { + return "", false + } + if s.Namespace != "" { + return s.Namespace + "." + s.Name, true } - return "", false + return s.Name, true } func buildArrowField(n *schemaNode, t arrow.DataType, m arrow.Metadata) arrow.Field { @@ -337,7 +419,7 @@ func avroPrimitiveToArrowType(avroFieldType string) arrow.DataType { func avroLogicalToArrowField(n *schemaNode) { var dt arrow.DataType // Avro logical types - switch lt := n.schema.(avro.LogicalTypeSchema).Logical(); lt.Type() { + switch n.node.LogicalType { // The decimal logical type represents an arbitrary-precision signed decimal number of the form unscaled × 10-scale. // A decimal logical type annotates Avro bytes or fixed types. The byte array must contain the two’s-complement // representation of the unscaled integer value in big-endian byte order. The scale is fixed, and is specified @@ -348,13 +430,13 @@ func avroLogicalToArrowField(n *schemaNode) { // precision, a JSON integer representing the (maximum) precision of decimals stored in this type (required). 
case "decimal": id := arrow.DECIMAL128 - if lt.(*avro.DecimalLogicalSchema).Precision() > decimal128.MaxPrecision { + if n.node.Precision > decimal128.MaxPrecision { id = arrow.DECIMAL256 } - dt, _ = arrow.NewDecimalType(id, int32(lt.(*avro.DecimalLogicalSchema).Precision()), int32(lt.(*avro.DecimalLogicalSchema).Scale())) + dt, _ = arrow.NewDecimalType(id, int32(n.node.Precision), int32(n.node.Scale)) - // The uuid logical type represents a random generated universally unique identifier (UUID). - // A uuid logical type annotates an Avro string. The string has to conform with RFC-4122 + // The uuid logical type represents a random generated universally unique identifier (UUID). + // A uuid logical type annotates an Avro string. The string has to conform with RFC-4122 case "uuid": dt = extensions.NewUUIDType() @@ -399,21 +481,19 @@ func avroLogicalToArrowField(n *schemaNode) { case "timestamp-micros": dt = arrow.FixedWidthTypes.Timestamp_us - // The local-timestamp-millis logical type represents a timestamp in a local timezone, regardless of - // what specific time zone is considered local, with a precision of one millisecond. - // A local-timestamp-millis logical type annotates an Avro long, where the long stores the number of - // milliseconds, from 1 January 1970 00:00:00.000. - // The local (wall-clock) semantics are preserved by leaving TimeZone unset, distinguishing these from - // the global timestamp-millis/micros types above which carry a UTC zone. + // The timestamp-nanos logical type represents an instant on the global timeline with nanosecond + // precision. twmb/avro decodes it to time.Time (UTC). + case "timestamp-nanos": + dt = arrow.FixedWidthTypes.Timestamp_ns + + // The local-timestamp-millis/micros/nanos logical types represent a timestamp in a local timezone. + // Arrow models that as a TimestampType with no time zone set. 
case "local-timestamp-millis": dt = &arrow.TimestampType{Unit: arrow.Millisecond} - - // The local-timestamp-micros logical type represents a timestamp in a local timezone, regardless of - // what specific time zone is considered local, with a precision of one microsecond. - // A local-timestamp-micros logical type annotates an Avro long, where the long stores the number of - // microseconds, from 1 January 1970 00:00:00.000000. case "local-timestamp-micros": dt = &arrow.TimestampType{Unit: arrow.Microsecond} + case "local-timestamp-nanos": + dt = &arrow.TimestampType{Unit: arrow.Nanosecond} // The duration logical type represents an amount of time defined by a number of months, days and milliseconds. // This is not equivalent to a number of milliseconds, because, depending on the moment in time from which the diff --git a/arrow/avro/schema_test.go b/arrow/avro/schema_test.go index 689ba3c43..bbfde5954 100644 --- a/arrow/avro/schema_test.go +++ b/arrow/avro/schema_test.go @@ -18,12 +18,12 @@ package avro import ( "fmt" - "strings" "testing" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/avro/testdata" - avropkg "github.com/hamba/avro/v2" + "github.com/apache/arrow-go/v18/arrow/extensions" + hambaAvro "github.com/hamba/avro/v2" ) func TestSchemaStringEqual(t *testing.T) { @@ -128,6 +128,10 @@ func TestSchemaStringEqual(t *testing.T) { Name: "uuidField", Type: arrow.BinaryTypes.String, }, + { + Name: "fixedUuidField", + Type: extensions.NewUUIDType(), + }, { Name: "timemillis", Type: arrow.FixedWidthTypes.Time32ms, @@ -172,7 +176,7 @@ func TestSchemaStringEqual(t *testing.T) { if err != nil { t.Fatalf("%v", err) } - got, err := ArrowSchemaFromAvro(schema) + got, err := ArrowSchemaFromAvroJSON(schema) if err != nil { t.Fatalf("%v", err) } @@ -185,25 +189,36 @@ func TestSchemaStringEqual(t *testing.T) { } } -func TestComplexUnionReportsError(t *testing.T) { - // Non-nullable union (e.g. 
[int, string]) is not supported and should - // produce a clear error rather than being silently dropped. - const avroSchemaJSON = `{ +// Remove together with [ArrowSchemaFromAvro] at the next major release. +func TestArrowSchemaFromAvro_Deprecated_PreservesLogicalTypesOnFixed(t *testing.T) { + const schemaJSON = `{ "type": "record", - "name": "WithComplexUnion", + "name": "Sample", "fields": [ - {"name": "value", "type": ["int", "string"]} + {"name": "id", "type": "int"}, + {"name": "name", "type": "string"}, + {"name": "nullable_double", "type": ["null", "double"]}, + {"name": "uuid_string", "type": {"type": "string", "logicalType": "uuid"}}, + {"name": "ts_millis", "type": {"type": "long", "logicalType": "timestamp-millis"}}, + {"name": "fixed_uuid", "type": {"type": "fixed", "name": "FUUID", "size": 16, "logicalType": "uuid"}}, + {"name": "fixed_decimal", "type": {"type": "fixed", "name": "FDec", "size": 16, "logicalType": "decimal", "precision": 20, "scale": 4}}, + {"name": "fixed_duration", "type": {"type": "fixed", "name": "FDur", "size": 12, "logicalType": "duration"}} ] }` - schema, err := avropkg.Parse(avroSchemaJSON) + hambaSchema, err := hambaAvro.Parse(schemaJSON) + if err != nil { + t.Fatalf("hamba parse: %v", err) + } + + got, err := ArrowSchemaFromAvro(hambaSchema) if err != nil { - t.Fatalf("avro parse: %v", err) + t.Fatalf("ArrowSchemaFromAvro: %v", err) } - got, err := ArrowSchemaFromAvro(schema) - if err == nil { - t.Fatalf("expected error for complex union, got schema=%v", got) + want, err := ArrowSchemaFromAvroJSON(schemaJSON) + if err != nil { + t.Fatalf("ArrowSchemaFromAvroJSON: %v", err) } - if !strings.Contains(err.Error(), "union") { - t.Fatalf("expected error to mention union, got: %v", err) + if got.String() != want.String() { + t.Fatalf("schema mismatch:\n got = %s\nwant = %s", got.String(), want.String()) } } diff --git a/arrow/avro/testdata/alltypes.avsc b/arrow/avro/testdata/alltypes.avsc index 27db841df..1da7332ea 100644 --- 
a/arrow/avro/testdata/alltypes.avsc +++ b/arrow/avro/testdata/alltypes.avsc @@ -164,6 +164,15 @@ "name": "uuidField", "type": "string" }, + { + "name": "fixedUuidField", + "type": { + "type": "fixed", + "name": "FixedUUID", + "size": 16, + "logicalType": "uuid" + } + }, { "name": "timemillis", "type": { diff --git a/arrow/avro/testdata/testdata.go b/arrow/avro/testdata/testdata.go index 4c8bac71d..fa0466763 100644 --- a/arrow/avro/testdata/testdata.go +++ b/arrow/avro/testdata/testdata.go @@ -17,8 +17,6 @@ package testdata import ( - "encoding/base64" - "encoding/binary" "encoding/json" "fmt" "log" @@ -28,8 +26,9 @@ import ( "strings" "time" - avro "github.com/hamba/avro/v2" - "github.com/hamba/avro/v2/ocf" + "github.com/google/uuid" + "github.com/twmb/avro" + "github.com/twmb/avro/ocf" ) const ( @@ -42,107 +41,175 @@ const ( type ByteArray []byte func (b ByteArray) MarshalJSON() ([]byte, error) { - return json.Marshal(base64.StdEncoding.EncodeToString(b)) + return json.Marshal([]byte(b)) } -type TimestampMicros int64 +type TimestampJSON time.Time -func (t TimestampMicros) MarshalJSON() ([]byte, error) { - ts := time.Unix(0, int64(t)*int64(time.Microsecond)).UTC().Format(time.RFC3339Nano) - return json.Marshal(ts) +func (t TimestampJSON) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Time(t).UTC().Format(time.RFC3339Nano)) } -type TimestampMillis int64 +type TimeMillisJSON time.Duration -func (t TimestampMillis) MarshalJSON() ([]byte, error) { - ts := time.Unix(0, int64(t)*int64(time.Millisecond)).UTC().Format(time.RFC3339Nano) - return json.Marshal(ts) -} - -type TimeMillis time.Duration - -func (t TimeMillis) MarshalJSON() ([]byte, error) { +func (t TimeMillisJSON) MarshalJSON() ([]byte, error) { ts := time.Unix(0, int64(t)).UTC().Format("15:04:05.000") return json.Marshal(strings.TrimRight(ts, "0.")) } -type TimeMicros time.Duration +type TimeMicrosJSON time.Duration -func (t TimeMicros) MarshalJSON() ([]byte, error) { +func (t TimeMicrosJSON) 
MarshalJSON() ([]byte, error) { ts := time.Unix(0, int64(t)).UTC().Format("15:04:05.000000") return json.Marshal(strings.TrimRight(ts, "0.")) } -type ExplicitNamespace [12]byte +type FixedJSON []byte -func (t ExplicitNamespace) MarshalJSON() ([]byte, error) { - return json.Marshal(t[:]) +func (t FixedJSON) MarshalJSON() ([]byte, error) { + return json.Marshal([]byte(t)) } -type MD5 [16]byte +type FixedUUIDJSON [16]byte -func (t MD5) MarshalJSON() ([]byte, error) { - return json.Marshal(t[:]) +func (t FixedUUIDJSON) MarshalJSON() ([]byte, error) { + return json.Marshal(uuid.UUID(t).String()) } -type DecimalType []byte +type DecimalJSON struct { + Rat *big.Rat +} -func (t DecimalType) MarshalJSON() ([]byte, error) { - v := new(big.Int).SetBytes(t) - s := fmt.Sprintf("%0*s", decimalTypeScale+1, v.String()) +func (t DecimalJSON) MarshalJSON() ([]byte, error) { + num := new(big.Int).Set(t.Rat.Num()) + den := new(big.Int).Set(t.Rat.Denom()) + scaleFactor := new(big.Int).Exp(big.NewInt(10), big.NewInt(decimalTypeScale), nil) + num.Mul(num, scaleFactor) + num.Quo(num, den) + s := fmt.Sprintf("%0*s", decimalTypeScale+1, num.String()) point := len(s) - decimalTypeScale return json.Marshal(s[:point] + "." 
+ s[point:]) } -type Duration [12]byte - -func (t Duration) MarshalJSON() ([]byte, error) { - milliseconds := int32(binary.LittleEndian.Uint32(t[8:12])) +type DurationJSON avro.Duration - m := map[string]interface{}{ - "months": int32(binary.LittleEndian.Uint32(t[0:4])), - "days": int32(binary.LittleEndian.Uint32(t[4:8])), - "nanoseconds": int64(milliseconds) * int64(time.Millisecond), +func (t DurationJSON) MarshalJSON() ([]byte, error) { + m := map[string]any{ + "months": int32(t.Months), + "days": int32(t.Days), + "nanoseconds": int64(t.Milliseconds) * int64(time.Millisecond), } return json.Marshal(m) } -type Date int32 +type DateJSON time.Time -func (t Date) MarshalJSON() ([]byte, error) { - v := time.Unix(int64(t)*86400, 0).UTC().Format("2006-01-02") - return json.Marshal(v) +func (t DateJSON) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Time(t).UTC().Format("2006-01-02")) } type Example struct { - InheritNull string `avro:"inheritNull" json:"inheritNull"` - ExplicitNamespace ExplicitNamespace `avro:"explicitNamespace" json:"explicitNamespace"` - FullName FullNameData `avro:"fullName" json:"fullName"` - ID int32 `avro:"id" json:"id"` - BigID int64 `avro:"bigId" json:"bigId"` - Temperature *float32 `avro:"temperature" json:"temperature"` - Fraction *float64 `avro:"fraction" json:"fraction"` - IsEmergency bool `avro:"is_emergency" json:"is_emergency"` - RemoteIP *ByteArray `avro:"remote_ip" json:"remote_ip"` - NullableRemoteIPS *[]ByteArray `avro:"nullable_remote_ips" json:"nullable_remote_ips"` - Person PersonData `avro:"person" json:"person"` - DecimalField DecimalType `avro:"decimalField" json:"decimalField"` - Decimal256Field DecimalType `avro:"decimal256Field" json:"decimal256Field"` - UUIDField string `avro:"uuidField" json:"uuidField"` - TimeMillis TimeMillis `avro:"timemillis" json:"timemillis"` - TimeMicros TimeMicros `avro:"timemicros" json:"timemicros"` - TimestampMillis TimestampMillis `avro:"timestampmillis" json:"timestampmillis"` - 
TimestampMicros TimestampMicros `avro:"timestampmicros" json:"timestampmicros"` - LocalTSMillis TimestampMillis `avro:"localtimestampmillis" json:"localtimestampmillis"` - LocalTSMicros TimestampMicros `avro:"localtimestampmicros" json:"localtimestampmicros"` - Duration Duration `avro:"duration" json:"duration"` - Date Date `avro:"date" json:"date"` + InheritNull string `avro:"inheritNull"` + ExplicitNamespace [12]byte `avro:"explicitNamespace"` + FullName FullNameData `avro:"fullName"` + ID int32 `avro:"id"` + BigID int64 `avro:"bigId"` + Temperature *float32 `avro:"temperature"` + Fraction *float64 `avro:"fraction"` + IsEmergency bool `avro:"is_emergency"` + RemoteIP *[]byte `avro:"remote_ip"` + NullableRemoteIPS *[][]byte `avro:"nullable_remote_ips"` + Person PersonData `avro:"person"` + DecimalField *big.Rat `avro:"decimalField"` + Decimal256Field *big.Rat `avro:"decimal256Field"` + UUIDField string `avro:"uuidField"` + FixedUUIDField [16]byte `avro:"fixedUuidField"` + TimeMillis time.Duration `avro:"timemillis"` + TimeMicros time.Duration `avro:"timemicros"` + TimestampMillis time.Time `avro:"timestampmillis"` + TimestampMicros time.Time `avro:"timestampmicros"` + LocalTSMillis time.Time `avro:"localtimestampmillis"` + LocalTSMicros time.Time `avro:"localtimestampmicros"` + Duration avro.Duration `avro:"duration"` + Date time.Time `avro:"date"` +} + +func (e Example) MarshalJSON() ([]byte, error) { + var remoteIP *ByteArray + if e.RemoteIP != nil { + v := ByteArray(*e.RemoteIP) + remoteIP = &v + } + var nullableRemoteIPs *[]ByteArray + if e.NullableRemoteIPS != nil { + arr := make([]ByteArray, len(*e.NullableRemoteIPS)) + for i, b := range *e.NullableRemoteIPS { + arr[i] = ByteArray(b) + } + nullableRemoteIPs = &arr + } + out := struct { + InheritNull string `json:"inheritNull"` + ExplicitNamespace FixedJSON `json:"explicitNamespace"` + FullName fullNameJSON `json:"fullName"` + ID int32 `json:"id"` + BigID int64 `json:"bigId"` + Temperature *float32 
`json:"temperature"` + Fraction *float64 `json:"fraction"` + IsEmergency bool `json:"is_emergency"` + RemoteIP *ByteArray `json:"remote_ip"` + NullableRemoteIPS *[]ByteArray `json:"nullable_remote_ips"` + Person PersonData `json:"person"` + DecimalField DecimalJSON `json:"decimalField"` + Decimal256Field DecimalJSON `json:"decimal256Field"` + UUIDField string `json:"uuidField"` + FixedUUIDField FixedUUIDJSON `json:"fixedUuidField"` + TimeMillis TimeMillisJSON `json:"timemillis"` + TimeMicros TimeMicrosJSON `json:"timemicros"` + TimestampMillis TimestampJSON `json:"timestampmillis"` + TimestampMicros TimestampJSON `json:"timestampmicros"` + LocalTSMillis TimestampJSON `json:"localtimestampmillis"` + LocalTSMicros TimestampJSON `json:"localtimestampmicros"` + Duration DurationJSON `json:"duration"` + Date DateJSON `json:"date"` + }{ + InheritNull: e.InheritNull, + ExplicitNamespace: FixedJSON(e.ExplicitNamespace[:]), + FullName: fullNameJSON{InheritNamespace: e.FullName.InheritNamespace, Md5: FixedJSON(e.FullName.Md5[:])}, + ID: e.ID, + BigID: e.BigID, + Temperature: e.Temperature, + Fraction: e.Fraction, + IsEmergency: e.IsEmergency, + RemoteIP: remoteIP, + NullableRemoteIPS: nullableRemoteIPs, + Person: e.Person, + DecimalField: DecimalJSON{Rat: e.DecimalField}, + Decimal256Field: DecimalJSON{Rat: e.Decimal256Field}, + UUIDField: e.UUIDField, + FixedUUIDField: FixedUUIDJSON(e.FixedUUIDField), + TimeMillis: TimeMillisJSON(e.TimeMillis), + TimeMicros: TimeMicrosJSON(e.TimeMicros), + TimestampMillis: TimestampJSON(e.TimestampMillis), + TimestampMicros: TimestampJSON(e.TimestampMicros), + LocalTSMillis: TimestampJSON(e.LocalTSMillis), + LocalTSMicros: TimestampJSON(e.LocalTSMicros), + Duration: DurationJSON(e.Duration), + Date: DateJSON(e.Date), + } + return json.Marshal(out) } type FullNameData struct { - InheritNamespace string `avro:"inheritNamespace" json:"inheritNamespace"` - Md5 MD5 `avro:"md5" json:"md5"` + InheritNamespace string `avro:"inheritNamespace"` + Md5 
[16]byte `avro:"md5"` } + +type fullNameJSON struct { + InheritNamespace string `json:"inheritNamespace"` + Md5 FixedJSON `json:"md5"` +} + type MapField map[string]int64 func (t MapField) MarshalJSON() ([]byte, error) { @@ -199,29 +266,43 @@ func TestdataDir() string { return "" } -func AllTypesAvroSchema() (avro.Schema, error) { +// AllTypesAvroSchema returns the raw JSON of the bundled `alltypes.avsc` +// testdata schema. +func AllTypesAvroSchema() (string, error) { sp := filepath.Join(TestdataDir(), SchemaFileName) avroSchemaBytes, err := os.ReadFile(sp) if err != nil { - return nil, err + return "", err } - return avro.ParseBytes(avroSchemaBytes) + return string(avroSchemaBytes), nil } func sampleData() Example { + now := time.Now().UTC() + // Truncate to each field's precision (millis, micros) so timestamp-millis/-micros round-trip exactly. + tsMillis := now.Truncate(time.Millisecond) + tsMicros := now.Truncate(time.Microsecond) + date := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) + + decimal := new(big.Rat).SetFrac(big.NewInt(9876), big.NewInt(100)) // 98.76 + decimal256, ok := new(big.Rat).SetString("12345678901234567890123456789012345678901234567890123456.78") + if !ok { + log.Fatal("bad decimal256 literal in sampleData") + } + return Example{ InheritNull: "a", - ExplicitNamespace: ExplicitNamespace{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, + ExplicitNamespace: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, FullName: FullNameData{ InheritNamespace: "d", - Md5: MD5{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + Md5: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, }, ID: 42, BigID: 42000000000, Temperature: func() *float32 { v := float32(36.6); return &v }(), Fraction: func() *float64 { v := float64(0.75); return &v }(), IsEmergency: true, - RemoteIP: func() *ByteArray { v := ByteArray{192, 168, 1, 1}; return &v }(), + RemoteIP: func() *[]byte { v := []byte{192, 168, 1, 1}; return &v }(), Person: PersonData{ Lastname: "Doe", Address: 
AddressUSRecord{ @@ -231,21 +312,18 @@ func sampleData() Example { Mapfield: MapField{"foo": 123}, ArrayField: []string{"one", "two"}, }, - DecimalField: DecimalType{0x00, 0x00, 0x00, 0x00, 0x00, 0x26, 0x94}, - Decimal256Field: DecimalType{ - 0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, - 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, - 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, 0x01, - }, + DecimalField: decimal, + Decimal256Field: decimal256, UUIDField: "123e4567-e89b-12d3-a456-426614174000", - TimeMillis: TimeMillis(50412345 * time.Millisecond), - TimeMicros: TimeMicros(50412345678 * time.Microsecond), - TimestampMillis: TimestampMillis(time.Now().UnixNano() / int64(time.Millisecond)), - TimestampMicros: TimestampMicros(time.Now().UnixNano() / int64(time.Microsecond)), - LocalTSMillis: TimestampMillis(time.Now().UnixNano() / int64(time.Millisecond)), - LocalTSMicros: TimestampMicros(time.Now().UnixNano() / int64(time.Microsecond)), - Duration: Duration{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, - Date: Date(time.Now().Unix() / 86400), + FixedUUIDField: [16]byte{0x55, 0x0e, 0x84, 0x00, 0xe2, 0x9b, 0x41, 0xd4, 0xa7, 0x16, 0x44, 0x66, 0x55, 0x44, 0x00, 0x00}, + TimeMillis: 50412345 * time.Millisecond, + TimeMicros: 50412345678 * time.Microsecond, + TimestampMillis: tsMillis, + TimestampMicros: tsMicros, + LocalTSMillis: tsMillis, + LocalTSMicros: tsMicros, + Duration: avro.Duration{Months: 1, Days: 2, Milliseconds: 3}, + Date: date, } } @@ -256,11 +334,16 @@ func writeOCFSampleData(td string, data Example) string { log.Fatal(err) } defer ocfFile.Close() - schema, err := AllTypesAvroSchema() + schemaJSON, err := AllTypesAvroSchema() + if err != nil { + log.Fatal(err) + } + schema, err := avro.Parse(schemaJSON) if err != nil { log.Fatal(err) } - encoder, err := ocf.NewEncoder(schema.String(), ocfFile) + // Pass the original JSON so logical-type annotations survive in the OCF header. 
+ encoder, err := ocf.NewWriter(ocfFile, schema, ocf.WithSchema(schemaJSON)) if err != nil { log.Fatal(err) } diff --git a/go.mod b/go.mod index 1a664b89d..7762fd070 100644 --- a/go.mod +++ b/go.mod @@ -43,6 +43,7 @@ require ( github.com/substrait-io/substrait-go/v8 v8.1.1 github.com/substrait-io/substrait-protobuf/go v0.85.0 github.com/tidwall/sjson v1.2.5 + github.com/twmb/avro v1.5.0 github.com/zeebo/xxh3 v1.1.0 golang.org/x/exp v0.0.0-20260112195511-716be5621a96 golang.org/x/sync v0.21.0 @@ -67,7 +68,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/goccy/go-yaml v1.17.1 // indirect - github.com/golang/snappy v1.0.0 // indirect github.com/gookit/color v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect diff --git a/go.sum b/go.sum index 6298d7dd7..d2738c6d1 100644 --- a/go.sum +++ b/go.sum @@ -54,8 +54,6 @@ github.com/goccy/go-yaml v1.17.1 h1:LI34wktB2xEE3ONG/2Ar54+/HJVBriAGJ55PHls4YuY= github.com/goccy/go-yaml v1.17.1/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= -github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= -github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v25.12.19+incompatible h1:haMV2JRRJCe1998HeW/p0X9UaMTK6SDo0ffLn2+DbLs= github.com/google/flatbuffers v25.12.19+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -164,6 +162,8 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 
h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= +github.com/twmb/avro v1.5.0 h1:9jmbvVQQBcyWHv/6zS+q5+nmASiR8/GwhKF/sU7u71c= +github.com/twmb/avro v1.5.0/go.mod h1:X0fT1dY2xcbV4YuCE4mYro+qljHl4kUF5uA/2z1rgSk= github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= From 6092d579ab203b3236b8bc2a900398d388b94490 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Proch=C3=A1zka?= Date: Sat, 20 Jun 2026 15:39:12 +0200 Subject: [PATCH 2/2] refactor(arrow/avro): address review feedback - Reject ambiguous bare named-type references: key the named-type cache by full name and error when an unqualified reference matches a short name defined under multiple namespaces, instead of resolving to an arbitrary one. - Guard that the top-level schema is a record (OCF requirement) and fail up front otherwise; the parser already guarantees a record has a name. - Simplify the union-branch loop in nullableBranch. - Document the 36-byte (hex-dash UUID text) case in appendUUIDData. - Add tests for the root guard (non-record and parser-rejected empty name) and the ambiguous-reference rejection. 
Co-Authored-By: Claude Opus 4.8 --- arrow/avro/reader_types.go | 2 + arrow/avro/schema.go | 75 ++++++++++++++++++++++++++--------- arrow/avro/schema_test.go | 81 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 18 deletions(-) diff --git a/arrow/avro/reader_types.go b/arrow/avro/reader_types.go index da13b03d6..52ac50b48 100644 --- a/arrow/avro/reader_types.go +++ b/arrow/avro/reader_types.go @@ -567,6 +567,8 @@ func appendUUIDData(b *extensions.UUIDBuilder, data any, fieldName string) error switch len(dt) { case 16: b.AppendBytes([16]byte(dt)) + // 36 bytes is the canonical hex-dash UUID text form + // (e.g. "550e8400-e29b-41d4-a716-446655440000") arriving as raw bytes. case 36: return b.AppendValueFromString(string(dt)) default: diff --git a/arrow/avro/schema.go b/arrow/avro/schema.go index 13214ca23..4e1e29d75 100644 --- a/arrow/avro/schema.go +++ b/arrow/avro/schema.go @@ -49,12 +49,24 @@ type schemaNode struct { nullable bool childrens []*schemaNode arrowField arrow.Field + // namedCache holds record/enum/fixed definitions keyed by full name + // (namespace.name, or the bare name when no namespace is present). namedCache map[string]avro.SchemaNode - index int32 + // bareToFull maps a short name to its sole full name, and ambiguousBare + // flags short names defined under more than one namespace so that a bare + // reference to them can be rejected instead of silently picking one. 
+ bareToFull map[string]string + ambiguousBare map[string]struct{} + index int32 } func newSchemaNode() *schemaNode { - return &schemaNode{index: -1, namedCache: map[string]avro.SchemaNode{}} + return &schemaNode{ + index: -1, + namedCache: map[string]avro.SchemaNode{}, + bareToFull: map[string]string{}, + ambiguousBare: map[string]struct{}{}, + } } func (node *schemaNode) schemaPath() string { @@ -69,11 +81,13 @@ func (node *schemaNode) schemaPath() string { func (node *schemaNode) newChild(n string, s avro.SchemaNode) *schemaNode { child := &schemaNode{ - name: n, - parent: node, - node: s, - namedCache: node.namedCache, - index: int32(len(node.childrens)), + name: n, + parent: node, + node: s, + namedCache: node.namedCache, + bareToFull: node.bareToFull, + ambiguousBare: node.ambiguousBare, + index: int32(len(node.childrens)), } node.childrens = append(node.childrens, child) return child @@ -81,34 +95,54 @@ func (node *schemaNode) newChild(n string, s avro.SchemaNode) *schemaNode { func (node *schemaNode) children() []*schemaNode { return node.childrens } // rememberNamed adds a record/enum/fixed SchemaNode to the named-type cache -// under both its short name and (if a namespace is present) its full name, -// so later references like {"type": "Address"} or {"type": "ns.Address"} -// resolve back to the original definition. +// keyed by its full name (namespace.name, or the bare name when no namespace +// is present), so a later {"type": "ns.Address"} reference resolves back to the +// original definition. It also records the short name so an unqualified +// {"type": "Address"} reference can resolve when unambiguous; if the same short +// name is defined under more than one namespace it is flagged as ambiguous and +// resolveRef will reject bare references to it. 
func (node *schemaNode) rememberNamed(s avro.SchemaNode) { if s.Name == "" { return } - node.namedCache[s.Name] = s + full := s.Name if s.Namespace != "" { - node.namedCache[s.Namespace+"."+s.Name] = s + full = s.Namespace + "." + s.Name + } + node.namedCache[full] = s + if existing, ok := node.bareToFull[s.Name]; ok && existing != full { + node.ambiguousBare[s.Name] = struct{}{} + } else { + node.bareToFull[s.Name] = full } } // resolveRef replaces s with its inline definition if s.Type is a named-type // reference rather than a builtin Avro type. atField, when non-empty, names // the field this reference appears in and is included in the panic so the -// user can locate the offending entry. +// user can locate the offending entry. A bare reference whose short name is +// defined under multiple namespaces is rejected rather than resolved to an +// arbitrary definition. func (node *schemaNode) resolveRef(s avro.SchemaNode, atField string) avro.SchemaNode { if _, ok := builtinAvroTypes[s.Type]; ok { return s } - if def, ok := node.namedCache[s.Type]; ok { - return def - } loc := node.schemaPath() if atField != "" { loc += "." + atField } + // An exact full-name match always wins. + if def, ok := node.namedCache[s.Type]; ok { + return def + } + // Otherwise treat s.Type as a short name, resolving it only when it is + // unambiguous across namespaces. 
+ if full, ok := node.bareToFull[s.Type]; ok { + if _, ambiguous := node.ambiguousBare[s.Type]; ambiguous { + panic(fmt.Errorf("ambiguous named type %q referenced at %s: defined in multiple namespaces; use a fully-qualified name", s.Type, loc)) + } + return node.namedCache[full] + } panic(fmt.Errorf("unknown named type %q referenced at %s", s.Type, loc)) } @@ -142,6 +176,12 @@ func arrowSchemaFromAvroInternal(schema *avro.Schema) (s *arrow.Schema, err erro } }() root := schema.Root() + // OCF requires the top-level schema to be a record; reject anything else up + // front instead of producing a degenerate or empty Arrow schema. (The parser + // already guarantees a record has a non-empty name.) + if root.Type != "record" { + panic(fmt.Errorf("avro schema root must be a record, got type %q", root.Type)) + } n := newSchemaNode() n.node = root c := n.newChild(root.Name, root) @@ -347,8 +387,7 @@ func nullableBranch(s avro.SchemaNode) (avro.SchemaNode, bool) { return avro.SchemaNode{}, false } var nonNull *avro.SchemaNode - for i := range s.Branches { - b := s.Branches[i] + for _, b := range s.Branches { if b.Type == "null" { continue } diff --git a/arrow/avro/schema_test.go b/arrow/avro/schema_test.go index bbfde5954..a5a1e3edf 100644 --- a/arrow/avro/schema_test.go +++ b/arrow/avro/schema_test.go @@ -18,12 +18,14 @@ package avro import ( "fmt" + "strings" "testing" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/avro/testdata" "github.com/apache/arrow-go/v18/arrow/extensions" hambaAvro "github.com/hamba/avro/v2" + avro "github.com/twmb/avro" ) func TestSchemaStringEqual(t *testing.T) { @@ -222,3 +224,82 @@ func TestArrowSchemaFromAvro_Deprecated_PreservesLogicalTypesOnFixed(t *testing. t.Fatalf("schema mismatch:\n got = %s\nwant = %s", got.String(), want.String()) } } + +// OCF requires a record at the top level. 
Non-record roots are rejected by our +// own guard, while a record with an empty name is rejected earlier by the avro +// parser; both must surface an error rather than produce a degenerate schema. +func TestArrowSchemaFromAvroJSON_RejectsInvalidRoot(t *testing.T) { + tests := []struct { + name string + schemaJSON string + wantErr string // caught by our guard + wantParse bool // caught by the avro parser before our guard + }{ + {name: "string root", schemaJSON: `"string"`, wantErr: "must be a record"}, + {name: "array root", schemaJSON: `{"type":"array","items":"int"}`, wantErr: "must be a record"}, + {name: "empty record name", schemaJSON: `{"type":"record","name":"","fields":[{"name":"x","type":"int"}]}`, wantParse: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ArrowSchemaFromAvroJSON(tt.schemaJSON) + if err == nil { + t.Fatalf("expected error for %s", tt.schemaJSON) + } + if tt.wantParse { + // The parser rejects it before our guard runs, so the error is + // not wrapped as an invalid avro schema by arrowSchemaFromAvroInternal. + if strings.Contains(err.Error(), "must be a record") { + t.Fatalf("expected parser error, got our guard's error: %v", err) + } + return + } + if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("unexpected error for %s: %v", tt.schemaJSON, err) + } + }) + } +} + +// A named record referenced again by name resolves back to the same definition. 
+func TestArrowSchemaFromAvroJSON_ReusedNamedReference(t *testing.T) { + const schemaJSON = `{"type":"record","name":"Root","fields":[ + {"name":"a","type":{"type":"record","name":"Foo","fields":[{"name":"x","type":"int"}]}}, + {"name":"b","type":"Foo"}]}` + s, err := ArrowSchemaFromAvroJSON(schemaJSON) + if err != nil { + t.Fatalf("ArrowSchemaFromAvroJSON: %v", err) + } + if s.NumFields() != 2 { + t.Fatalf("got %d fields, want 2", s.NumFields()) + } + if a, b := s.Field(0).Type.String(), s.Field(1).Type.String(); a != b { + t.Fatalf("reused reference resolved to a different type:\n a = %s\n b = %s", a, b) + } +} + +// Two records sharing a short name across namespaces are kept distinct by full +// name, and an unqualified reference to that short name is rejected instead of +// silently resolving to one of them (restoring hamba's erroring behavior). +func TestResolveRefAmbiguousBareName(t *testing.T) { + n := newSchemaNode() + n.rememberNamed(avro.SchemaNode{Name: "Foo", Namespace: "a"}) + n.rememberNamed(avro.SchemaNode{Name: "Foo", Namespace: "b"}) + + got := n.resolveRef(avro.SchemaNode{Type: "a.Foo"}, "") + if got.Namespace != "a" { + t.Fatalf("a.Foo resolved to namespace %q, want a", got.Namespace) + } + + func() { + defer func() { + r := recover() + if r == nil { + t.Fatalf("expected panic for ambiguous bare reference") + } + if !strings.Contains(fmt.Sprint(r), "ambiguous named type") { + t.Fatalf("unexpected panic: %v", r) + } + }() + n.resolveRef(avro.SchemaNode{Type: "Foo"}, "field") + }() +}