summaryrefslogtreecommitdiffhomepage
path: root/pkg/state/decode.go
blob: 47e6b878a4ec5b3d096d1f014164f930eefc5924 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package state

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"reflect"
	"sort"

	"github.com/golang/protobuf/proto"
	pb "gvisor.dev/gvisor/pkg/state/object_go_proto"
)

// objectState represents an object that may be in the process of being
// decoded. Specifically, it represents either a decoded object, or an
// interest in a future object that will be decoded. When that interest is
// registered (via register), the storage for the object will be created, but
// it will not be decoded until the object is encountered in the stream.
type objectState struct {
	// id is the id for this object.
	//
	// If this field is zero, then this is an anonymous (unregistered,
	// non-reference primitive) object. This is immutable.
	id uint64

	// obj is the object. This may or may not be valid yet, depending on
	// whether complete returns true. However, regardless of whether the
	// object is valid, obj contains a final storage location for the
	// object. This is immutable.
	//
	// Note that this must be addressable (obj.Addr() must not panic).
	//
	// The obj passed to the decode methods below will equal this obj only
	// in the case of decoding the top-level object. However, the passed
	// obj may represent individual fields, elements of a slice, etc. that
	// are effectively embedded within the reflect.Value below but with
	// distinct types.
	obj reflect.Value

	// blockedBy is the number of dependencies this object has.
	//
	// It is incremented by waitFor and decremented by checkComplete when a
	// dependency completes; the object is only finalized once it reaches
	// zero.
	blockedBy int

	// blocking is a list of the objects blocked by this one.
	//
	// checkComplete walks this list (decrementing each entry's blockedBy)
	// once this object completes, then clears it.
	blocking []*objectState

	// callbacks is a set of callbacks to execute on load.
	//
	// waitFor queues these on the dependency, and checkComplete fires and
	// clears them once this object is no longer blocked.
	callbacks []func()

	// path is the decoding path to the object.
	//
	// It is a copy of the decoder's recoverable path taken when the object
	// is registered, and is used to restore error context when the object
	// is eventually decoded.
	path recoverable
}

// complete reports whether the object is fully loaded: nothing still blocks
// it and no completion callbacks remain queued on it.
func (os *objectState) complete() bool {
	if os.blockedBy != 0 {
		return false
	}
	return len(os.callbacks) == 0
}

// checkComplete finalizes os once nothing blocks it any more. It runs the
// callbacks queued on os (then drops them), and releases every object waiting
// on os by decrementing its blockedBy and recursing into it, so completion
// cascades downstream. An object that is still blocked is left untouched.
func (os *objectState) checkComplete(stats *Stats) {
	if os.blockedBy > 0 {
		// At least one dependency is still pending.
		return
	}
	stats.Start(os.obj)

	// Snapshot the callbacks: anything appended while they run is dropped
	// along with them.
	pending := os.callbacks
	for i := range pending {
		pending[i]()
	}
	os.callbacks = nil

	waiters := os.blocking
	for i := range waiters {
		w := waiters[i]
		w.blockedBy--
		w.checkComplete(stats)
	}
	os.blocking = nil

	stats.Done()
}

// waitFor records that os depends on other. os gains one more blocker,
// other learns that it blocks os, and callback (when non-nil) is queued on
// other so that it fires when other completes.
func (os *objectState) waitFor(other *objectState, callback func()) {
	os.blockedBy++
	other.blocking = append(other.blocking, os)
	if callback == nil {
		return
	}
	other.callbacks = append(other.callbacks, callback)
}

// findCycleFor searches the objects transitively blocked by os for target.
//
// If target is reachable, it returns the path in reverse order: target
// first, followed by each intermediate object back towards os (os itself is
// not included). It returns nil if target is not reachable.
//
// Each object is explored at most once. Without that, a cycle in the
// blocking graph that does not pass through target would recurse forever and
// overflow the stack, which is fatal rather than a recoverable panic.
func (os *objectState) findCycleFor(target *objectState) []*objectState {
	return os.findCycleForVisited(target, make(map[*objectState]struct{}))
}

// findCycleForVisited is findCycleFor with a set of already-explored objects.
func (os *objectState) findCycleForVisited(target *objectState, visited map[*objectState]struct{}) []*objectState {
	for _, other := range os.blocking {
		if other == target {
			return []*objectState{target}
		}
		if _, seen := visited[other]; seen {
			// Already explored (or being explored) without reaching
			// target via this node; revisiting cannot help.
			continue
		}
		visited[other] = struct{}{}
		if childList := other.findCycleForVisited(target, visited); childList != nil {
			return append(childList, other)
		}
	}
	return nil
}

// findCycle returns a dependency cycle through os as a list of objects that
// ends with os itself. When no path of blocked objects leads back to os, the
// returned list holds only os.
func (os *objectState) findCycle() []*objectState {
	path := os.findCycleFor(os)
	return append(path, os)
}

// decodeState is a graph of objects in the process of being decoded.
//
// The decode process involves loading the breadth-first graph generated by
// encode. This graph is read in its entirety, ensuring that all object
// storage is complete.
//
// As the graph is being deserialized, a set of completion callbacks are
// executed. These completion callbacks should form a set of acyclic subgraphs
// over the original one. After decoding is complete, the objects are scanned
// to ensure that all callbacks are executed, otherwise the callback graph was
// not acyclic.
type decodeState struct {
	// objectsByID is the set of objects in progress, keyed by object id.
	objectsByID map[uint64]*objectState

	// deferred are objects that have been read, but no interest has been
	// registered yet. These will be decoded once interest is registered.
	deferred map[uint64]*pb.Object

	// outstanding is the number of outstanding objects: those registered
	// but not yet read from the stream. Deserialize keeps reading objects
	// until this drops to zero.
	outstanding uint32

	// r is the input stream.
	r io.Reader

	// stats is the passed stats object.
	stats *Stats

	// recoverable is the panic recover facility.
	recoverable
}

// lookup returns the objectState registered under id, or nil when nothing has
// been registered for that id yet.
func (ds *decodeState) lookup(id uint64) *objectState {
	if found, ok := ds.objectsByID[id]; ok {
		return found
	}
	return nil
}

// wait makes waiter depend on the object registered under id, queuing
// callback (if non-nil) to run when that object completes.
//
// Three ids never block: 0 (a nil reference), waiter's own id (a trivial self
// reference, handled internally), and 1 (the root object, which is always
// allowed because it may have fields that are already decoded). For those the
// callback runs immediately.
func (ds *decodeState) wait(waiter *objectState, id uint64, callback func()) {
	if id == 0 || id == waiter.id || id == 1 {
		if callback != nil {
			callback()
		}
		return
	}

	// The referenced object has already been registered, so lookup never
	// yields nil here.
	target := ds.lookup(id)
	waiter.waitFor(target, callback)
}

// waitObject makes os wait for whatever object the encoded value p refers to.
// References and slices wait on their ref id (see wait); interfaces recurse
// into the value they wrap; any other value has nothing to wait for, so
// callback (if non-nil) runs immediately.
func (ds *decodeState) waitObject(os *objectState, p *pb.Object, callback func()) {
	switch v := p.Value.(type) {
	case *pb.Object_RefValue:
		// Refs can encode pointers and maps.
		ds.wait(os, v.RefValue, callback)
	case *pb.Object_SliceValue:
		// Slices depend on their backing array (if non-nil); see
		// decodeObject.
		ds.wait(os, v.SliceValue.RefValue, callback)
	case *pb.Object_InterfaceValue:
		// Interfaces wait on whatever they wrap.
		ds.waitObject(os, v.InterfaceValue.Value, callback)
	default:
		if callback != nil {
			callback()
		}
	}
}

// register returns the objectState for id, creating it on first use with
// fresh storage of type typ (a new map for map types, a new zero value
// otherwise). typ is ignored when id is already registered.
//
// If the object for id was already read from the stream before anyone
// registered interest in it, it is decoded immediately; otherwise the new
// registration is counted as outstanding.
func (ds *decodeState) register(id uint64, typ reflect.Type) *objectState {
	if existing, found := ds.objectsByID[id]; found {
		return existing
	}

	var storage reflect.Value
	if typ.Kind() == reflect.Map {
		storage = reflect.MakeMap(typ)
	} else {
		storage = reflect.New(typ).Elem()
	}
	state := &objectState{id: id, obj: storage, path: ds.recoverable.copy()}
	ds.objectsByID[id] = state

	pending, wasRead := ds.deferred[id]
	if !wasRead {
		// Not read yet; Deserialize will decode it when it arrives.
		ds.outstanding++
		return state
	}

	delete(ds.deferred, id) // Free memory.
	ds.decodeObject(state, state.obj, pending, "", nil)
	return state
}

// decodeStruct decodes a struct value by handing its serialized fields,
// sorted by name, to the load function registered for the struct's pointer
// type. A struct with no fields needs no registration; any other unregistered
// struct type panics with an error.
func (ds *decodeState) decodeStruct(os *objectState, obj reflect.Value, s *pb.Struct) {
	m := Map{newInternalMap(nil, ds, os)}
	defer internalMapPool.Put(m.internalMap)

	for i := range s.Fields {
		f := s.Fields[i]
		m.data = append(m.data, entry{name: f.Name, object: f.Value})
	}

	// The encoder should already emit fields in order, so this is normally
	// a single verifying pass; it keeps lookups by name efficient.
	if len(m.data) > 1 {
		sort.Slice(m.data, func(a, b int) bool { return m.data[a].name < m.data[b].name })
	}

	addr := obj.Addr()
	fns, registered := registeredTypes.lookupFns(addr.Type())
	switch {
	case registered:
		// The loader recursively decodes the remaining objects.
		fns.invokeLoad(addr, m)
	case obj.NumField() == 0:
		// Anonymous empty structs are allowed.
	default:
		panic(fmt.Errorf("unregistered type %s", obj.Type()))
	}
}

// decodeMap decodes a map value into obj, allocating the map if obj is nil.
//
// Keys and values are serialized as parallel lists; each pair is decoded,
// os is made to wait on both halves, and the pair is inserted into the map.
// Mismatched key/value counts indicate a corrupt stream and cause a panic
// with a descriptive error, rather than an index-out-of-range runtime panic
// (or silently dropping trailing values).
func (ds *decodeState) decodeMap(os *objectState, obj reflect.Value, m *pb.Map) {
	if len(m.Keys) != len(m.Values) {
		panic(fmt.Errorf("mismatching map key/value count keys=%d, values=%d", len(m.Keys), len(m.Values)))
	}
	if obj.IsNil() {
		obj.Set(reflect.MakeMap(obj.Type()))
	}
	keyType, elemType := obj.Type().Key(), obj.Type().Elem()
	for i, key := range m.Keys {
		value := m.Values[i]

		// Decode the objects.
		kv := reflect.New(keyType).Elem()
		vv := reflect.New(elemType).Elem()
		ds.decodeObject(os, kv, key, ".(key %d)", i)
		ds.decodeObject(os, vv, value, "[%#v]", kv.Interface())
		ds.waitObject(os, key, nil)
		ds.waitObject(os, value, nil)

		// Set in the map.
		obj.SetMapIndex(kv, vv)
	}
}

// decodeArray decodes each serialized element into the matching index of the
// fixed-length array obj, making os wait on every element. The serialized
// element count must equal obj's length, otherwise it panics with an error.
func (ds *decodeState) decodeArray(os *objectState, obj reflect.Value, a *pb.Array) {
	want, got := obj.Len(), len(a.Contents)
	if got != want {
		panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
	}
	for i, elem := range a.Contents {
		ds.decodeObject(os, obj.Index(i), elem, "[%d]", i)
		ds.waitObject(os, elem, nil)
	}
}

// decodeInterface decodes an interface value. An empty type name means a nil
// interface and leaves obj untouched. Otherwise the named type must be
// registered; obj is first given that dynamic type (except for maps), and the
// wrapped value is then decoded into it. No wait is needed because the
// interface shares os with its contents.
func (ds *decodeState) decodeInterface(os *objectState, obj reflect.Value, i *pb.Interface) {
	typeName := i.Type
	if len(typeName) == 0 {
		return
	}

	// The concrete type may go unused if the reference was already
	// resolved, but register needs it to create the object otherwise.
	concrete, found := registeredTypes.lookupType(typeName)
	if !found {
		panic(fmt.Errorf("no valid type for %q", typeName))
	}

	// Storing a typed zero value imparts the type information that
	// decodeObject relies on; maps do not need it.
	if isMap := obj.Kind() == reflect.Map; !isMap {
		obj.Set(reflect.Zero(concrete))
	}

	ds.decodeObject(os, obj, i.Value, ".(%s)", typeName)
}

// decodeObject decodes object into obj, which must be settable.
//
// os is the objectState that owns obj's storage. format and param describe
// the path element pushed for error reporting while this value is decoded.
//
// References (pointers, maps, slices) are resolved to their final storage via
// register, but the referenced object may not have been decoded yet; use
// wait/waitObject to order dependent work. Malformed or mismatched input
// panics.
func (ds *decodeState) decodeObject(os *objectState, obj reflect.Value, object *pb.Object, format string, param interface{}) {
	ds.push(false, format, param)
	ds.stats.Add(obj)
	ds.stats.Start(obj)

	switch x := object.GetValue().(type) {
	case *pb.Object_BoolValue:
		obj.SetBool(x.BoolValue)
	case *pb.Object_StringValue:
		obj.SetString(string(x.StringValue))
	case *pb.Object_Int64Value:
		obj.SetInt(x.Int64Value)
		if obj.Int() != x.Int64Value {
			panic(fmt.Errorf("signed integer truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_Uint64Value:
		obj.SetUint(x.Uint64Value)
		if obj.Uint() != x.Uint64Value {
			panic(fmt.Errorf("unsigned integer truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_DoubleValue:
		obj.SetFloat(x.DoubleValue)
		// NaN never compares equal to itself (v != v holds exactly when v
		// is NaN), so a NaN that round-trips through a float32/float64
		// field must not be mistaken for truncation.
		if got, want := obj.Float(), x.DoubleValue; got != want && !(got != got && want != want) {
			panic(fmt.Errorf("float truncated in %v for %s", object, obj.Type()))
		}
	case *pb.Object_RefValue:
		// Resolve the pointer itself, even though the object may not
		// be decoded yet. You need to use wait() in order to ensure
		// that is the case. See wait above, and Map.Barrier.
		if id := x.RefValue; id != 0 {
			// Decoding the interface should have imparted type
			// information, so from this point it's safe to resolve
			// and use this dynamic information for actually
			// creating the object in register.
			//
			// (For non-interfaces this is a no-op).
			dyntyp := reflect.TypeOf(obj.Interface())
			if dyntyp.Kind() == reflect.Map {
				// Remove the map object count here to avoid
				// double counting, as this object will be
				// counted again when it gets processed later.
				// We do not add a reference count as the
				// reference is artificial.
				ds.stats.Remove(obj)
				obj.Set(ds.register(id, dyntyp).obj)
			} else if dyntyp.Kind() == reflect.Ptr {
				ds.push(true /* dereference */, "", nil)
				obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
				ds.pop()
			} else {
				obj.Set(ds.register(id, dyntyp.Elem()).obj.Addr())
			}
		} else {
			// We leave obj alone here. That's because if obj
			// represents an interface, it may have been imbued
			// with type information in decodeInterface, and we
			// don't want to destroy that information.
		}
	case *pb.Object_SliceValue:
		// It's okay to slice the array here, since the contents will
		// still be provided later on. These semantics are a bit
		// strange but they are handled in the Map.Barrier properly.
		//
		// The special semantics of zero ref apply here too.
		if id := x.SliceValue.RefValue; id != 0 && x.SliceValue.Capacity > 0 {
			v := reflect.ArrayOf(int(x.SliceValue.Capacity), obj.Type().Elem())
			obj.Set(ds.register(id, v).obj.Slice3(0, int(x.SliceValue.Length), int(x.SliceValue.Capacity)))
		}
	case *pb.Object_ArrayValue:
		ds.decodeArray(os, obj, x.ArrayValue)
	case *pb.Object_StructValue:
		ds.decodeStruct(os, obj, x.StructValue)
	case *pb.Object_MapValue:
		ds.decodeMap(os, obj, x.MapValue)
	case *pb.Object_InterfaceValue:
		ds.decodeInterface(os, obj, x.InterfaceValue)
	case *pb.Object_ByteArrayValue:
		copyArray(obj, reflect.ValueOf(x.ByteArrayValue))
	case *pb.Object_Uint16ArrayValue:
		// 16-bit slices are serialized as 32-bit slices.
		// See object.proto for details.
		s := x.Uint16ArrayValue.Values
		t := obj.Slice(0, obj.Len()).Interface().([]uint16)
		if len(t) != len(s) {
			panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
		}
		for i := range s {
			t[i] = uint16(s[i])
		}
	case *pb.Object_Uint32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Uint32ArrayValue.Values))
	case *pb.Object_Uint64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Uint64ArrayValue.Values))
	case *pb.Object_UintptrArrayValue:
		copyArray(obj, castSlice(reflect.ValueOf(x.UintptrArrayValue.Values), reflect.TypeOf(uintptr(0))))
	case *pb.Object_Int8ArrayValue:
		copyArray(obj, castSlice(reflect.ValueOf(x.Int8ArrayValue.Values), reflect.TypeOf(int8(0))))
	case *pb.Object_Int16ArrayValue:
		// 16-bit slices are serialized as 32-bit slices.
		// See object.proto for details.
		s := x.Int16ArrayValue.Values
		t := obj.Slice(0, obj.Len()).Interface().([]int16)
		if len(t) != len(s) {
			panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", len(t), len(s)))
		}
		for i := range s {
			t[i] = int16(s[i])
		}
	case *pb.Object_Int32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Int32ArrayValue.Values))
	case *pb.Object_Int64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Int64ArrayValue.Values))
	case *pb.Object_BoolArrayValue:
		copyArray(obj, reflect.ValueOf(x.BoolArrayValue.Values))
	case *pb.Object_Float64ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Float64ArrayValue.Values))
	case *pb.Object_Float32ArrayValue:
		copyArray(obj, reflect.ValueOf(x.Float32ArrayValue.Values))
	default:
		// Should not happen, not propagated as an error.
		panic(fmt.Sprintf("unknown object %v for %s", object, obj.Type()))
	}

	ds.stats.Done()
	ds.pop()
}

// copyArray copies src into dest (via reflect.Copy) after converting src's
// elements to dest's element type with castSlice. It panics with an error
// when the two lengths differ.
func copyArray(dest reflect.Value, src reflect.Value) {
	want := dest.Len()
	if got := src.Len(); got != want {
		panic(fmt.Errorf("mismatching array length expect=%d, actual=%d", want, got))
	}
	converted := castSlice(src, dest.Type().Elem())
	reflect.Copy(dest, converted)
}

// Deserialize deserializes the object state.
//
// obj is the storage for the root object (id 1). Objects are read from the
// stream in id order until every registered interest has been satisfied;
// objects read before any interest in them exists are stashed in ds.deferred
// and decoded later by register. Afterwards the zero-length terminal header
// is verified, all completion callbacks are fired, and any object that is
// still incomplete is reported as a dependency cycle.
//
// This function may panic and should be run in safely().
func (ds *decodeState) Deserialize(obj reflect.Value) {
	ds.objectsByID[1] = &objectState{id: 1, obj: obj, path: ds.recoverable.copy()}
	ds.outstanding = 1 // The root object.

	// Decode all objects in the stream.
	//
	// See above, we never process objects while we have no outstanding
	// interests (other than the very first object).
	for id := uint64(1); ds.outstanding > 0; id++ {
		// os is nil when nothing has registered interest in this id yet,
		// so it must not be dereferenced before the nil check.
		os := ds.lookup(id)
		if os != nil {
			ds.stats.Start(os.obj)
		}

		o, err := ds.readObject()
		if err != nil {
			panic(err)
		}

		if os == nil {
			// If an object hasn't had interest registered
			// previously, we defer decoding until interest is
			// registered.
			ds.deferred[id] = o
			continue
		}

		// Decode the object.
		ds.from = &os.path
		ds.decodeObject(os, os.obj, o, "", nil)
		ds.outstanding--
		ds.stats.Done()
	}

	// Check the zero-length header at the end.
	length, object, err := ReadHeader(ds.r)
	if err != nil {
		panic(err)
	}
	if length != 0 {
		panic(fmt.Sprintf("expected zero-length terminal, got %d", length))
	}
	if object {
		panic("expected non-object terminal")
	}

	// Check if we have any deferred objects.
	if count := len(ds.deferred); count > 0 {
		// Should not happen, not propagated as an error.
		panic(fmt.Sprintf("still have %d deferred objects", count))
	}

	// Scan and fire all callbacks.
	for _, os := range ds.objectsByID {
		os.checkComplete(ds.stats)
	}

	// Check if we have any remaining dependency cycles.
	for _, os := range ds.objectsByID {
		if !os.complete() {
			// This must be the result of a dependency cycle.
			cycle := os.findCycle()
			var buf bytes.Buffer
			buf.WriteString("dependency cycle: {")
			for i, cycleOS := range cycle {
				if i > 0 {
					buf.WriteString(" => ")
				}
				buf.WriteString(cycleOS.obj.Type().String())
			}
			buf.WriteString("}")
			// Panic as an error; propagate to the caller.
			panic(errors.New(buf.String()))
		}
	}
}

// byteReader adapts an io.Reader into an io.ByteReader by reading one byte at
// a time, without any buffering.
type byteReader struct {
	io.Reader
}

// ReadByte implements io.ByteReader.
//
// A read that yields a byte succeeds even if the underlying reader also
// reported an error. A read that yields nothing returns that error, or
// io.ErrUnexpectedEOF when the reader returned neither data nor an error.
func (br byteReader) ReadByte() (byte, error) {
	var one [1]byte
	switch n, err := br.Reader.Read(one[:]); {
	case n > 0:
		return one[0], nil
	case err != nil:
		return 0, err
	default:
		return 0, io.ErrUnexpectedEOF
	}
}

// ReadHeader reads an object header.
//
// Each object written to the statefile is prefixed with a header. See
// WriteHeader for more information; these functions are exported to allow
// non-state writes to the file to play nice with debugging tools.
//
// The header is a single uvarint whose lowest bit says whether an object
// follows and whose remaining bits hold the length. If reading the uvarint
// fails, the raw (partially read, unshifted) value is returned as length
// together with the error.
func ReadHeader(r io.Reader) (length uint64, object bool, err error) {
	raw, readErr := binary.ReadUvarint(byteReader{r})
	if readErr != nil {
		return raw, false, readErr
	}
	return raw >> 1, raw&0x1 != 0, nil
}

// readObject reads the next object from the stream: a header (which must mark
// an object), followed by exactly that many bytes of serialized pb.Object.
//
// A stream that ends partway through the object body yields
// io.ErrUnexpectedEOF; a stream that ends before any body byte is read yields
// io.EOF.
func (ds *decodeState) readObject() (*pb.Object, error) {
	// Read the header.
	length, object, err := ReadHeader(ds.r)
	if err != nil {
		return nil, err
	}
	if !object {
		return nil, fmt.Errorf("invalid object header")
	}

	// Read the object. io.ReadFull stops at the first read error and
	// reports a short body as io.ErrUnexpectedEOF, instead of surfacing a
	// bare io.EOF for a truncated stream.
	buf := make([]byte, length)
	if _, err := io.ReadFull(ds.r, buf); err != nil {
		return nil, err
	}

	// Unmarshal.
	obj := new(pb.Object)
	if err := proto.Unmarshal(buf, obj); err != nil {
		return nil, err
	}

	return obj, nil
}