package gorm import ( "errors" "fmt" "reflect" "strconv" "strings" ) // preloadCallback used to preload associations func preloadCallback(scope *Scope) { if _, ok := scope.Get("gorm:auto_preload"); ok { autoPreload(scope) } if scope.Search.preload == nil || scope.HasError() { return } var ( preloadedMap = map[string]bool{} fields = scope.Fields() ) for _, preload := range scope.Search.preload { var ( preloadFields = strings.Split(preload.schema, ".") currentScope = scope currentFields = fields ) for idx, preloadField := range preloadFields { var currentPreloadConditions []interface{} if currentScope == nil { continue } // if not preloaded if preloadKey := strings.Join(preloadFields[:idx+1], "."); !preloadedMap[preloadKey] { // assign search conditions to last preload if idx == len(preloadFields)-1 { currentPreloadConditions = preload.conditions } for _, field := range currentFields { if field.Name != preloadField || field.Relationship == nil { continue } switch field.Relationship.Kind { case "has_one": currentScope.handleHasOnePreload(field, currentPreloadConditions) case "has_many": currentScope.handleHasManyPreload(field, currentPreloadConditions) case "belongs_to": currentScope.handleBelongsToPreload(field, currentPreloadConditions) case "many_to_many": currentScope.handleManyToManyPreload(field, currentPreloadConditions) default: scope.Err(errors.New("unsupported relation")) } preloadedMap[preloadKey] = true break } if !preloadedMap[preloadKey] { scope.Err(fmt.Errorf("can't preload field %s for %s", preloadField, currentScope.GetModelStruct().ModelType)) return } } // preload next level if idx < len(preloadFields)-1 { currentScope = currentScope.getColumnAsScope(preloadField) if currentScope != nil { currentFields = currentScope.Fields() } } } } } func autoPreload(scope *Scope) { for _, field := range scope.Fields() { if field.Relationship == nil { continue } if val, ok := field.TagSettings["PRELOAD"]; ok { if preload, err := strconv.ParseBool(val); err != nil { scope.Err(errors.New("invalid preload option")) return } else if !preload { continue } } scope.Search.Preload(field.Name) } } func (scope *Scope) generatePreloadDBWithConditions(conditions []interface{}) (*DB, []interface{}) { var ( preloadDB = scope.NewDB() preloadConditions []interface{} ) for _, condition := range conditions { if scopes, ok := condition.(func(*DB) *DB); ok { preloadDB = scopes(preloadDB) } else { preloadConditions = append(preloadConditions, condition) } } return preloadDB, preloadConditions } // handleHasOnePreload used to preload has one associations func (scope *Scope) handleHasOnePreload(field *Field, conditions []interface{}) { relation := field.Relationship // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // find relations query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) values := toQueryValues(primaryKeys) if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) values = append(values, relation.PolymorphicValue) } results := makeSlice(field.Struct.Type) scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) // assign find results var ( resultsValue = indirect(reflect.ValueOf(results)) indirectScopeValue = scope.IndirectValue() ) if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) foreignValues := getValueFromFields(result, relation.ForeignFieldNames) if indirectValue := indirect(indirectScopeValue.Index(j)); equalAsString(getValueFromFields(indirectValue, relation.AssociationForeignFieldNames), foreignValues) { indirectValue.FieldByName(field.Name).Set(result) break } } } } else { for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) scope.Err(field.Set(result)) } } } // handleHasManyPreload used to preload has many associations func (scope *Scope) handleHasManyPreload(field *Field, conditions []interface{}) { relation := field.Relationship // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.AssociationForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // find relations query := fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.ForeignDBNames), toQueryMarks(primaryKeys)) values := toQueryValues(primaryKeys) if relation.PolymorphicType != "" { query += fmt.Sprintf(" AND %v = ?", scope.Quote(relation.PolymorphicDBName)) values = append(values, relation.PolymorphicValue) } results := makeSlice(field.Struct.Type) scope.Err(preloadDB.Where(query, values...).Find(results, preloadConditions...).Error) // assign find results var ( resultsValue = indirect(reflect.ValueOf(results)) indirectScopeValue = scope.IndirectValue() ) if indirectScopeValue.Kind() == reflect.Slice { preloadMap := make(map[string][]reflect.Value) for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) foreignValues := getValueFromFields(result, relation.ForeignFieldNames) preloadMap[toString(foreignValues)] = append(preloadMap[toString(foreignValues)], result) } for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) objectRealValue := getValueFromFields(object, relation.AssociationForeignFieldNames) f := object.FieldByName(field.Name) if results, ok := preloadMap[toString(objectRealValue)]; ok { f.Set(reflect.Append(f, results...)) } else { f.Set(reflect.MakeSlice(f.Type(), 0, 0)) } } } else { scope.Err(field.Set(resultsValue)) } } // handleBelongsToPreload used to preload belongs to associations func (scope *Scope) handleBelongsToPreload(field *Field, conditions []interface{}) { relation := field.Relationship // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // get relations's primary keys primaryKeys := scope.getColumnAsArray(relation.ForeignFieldNames, scope.Value) if len(primaryKeys) == 0 { return } // find relations results := makeSlice(field.Struct.Type) scope.Err(preloadDB.Where(fmt.Sprintf("%v IN (%v)", toQueryCondition(scope, relation.AssociationForeignDBNames), toQueryMarks(primaryKeys)), toQueryValues(primaryKeys)...).Find(results, preloadConditions...).Error) // assign find results var ( resultsValue = indirect(reflect.ValueOf(results)) indirectScopeValue = scope.IndirectValue() ) for i := 0; i < resultsValue.Len(); i++ { result := resultsValue.Index(i) if indirectScopeValue.Kind() == reflect.Slice { value := getValueFromFields(result, relation.AssociationForeignFieldNames) for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) if equalAsString(getValueFromFields(object, relation.ForeignFieldNames), value) { object.FieldByName(field.Name).Set(result) } } } else { scope.Err(field.Set(result)) } } } // handleManyToManyPreload used to preload many to many associations func (scope *Scope) handleManyToManyPreload(field *Field, conditions []interface{}) { var ( relation = field.Relationship joinTableHandler = relation.JoinTableHandler fieldType = field.Struct.Type.Elem() foreignKeyValue interface{} foreignKeyType = reflect.ValueOf(&foreignKeyValue).Type() linkHash = map[string][]reflect.Value{} isPtr bool ) if fieldType.Kind() == reflect.Ptr { isPtr = true fieldType = fieldType.Elem() } var sourceKeys = []string{} for _, key := range joinTableHandler.SourceForeignKeys() { sourceKeys = append(sourceKeys, key.DBName) } // preload conditions preloadDB, preloadConditions := scope.generatePreloadDBWithConditions(conditions) // generate query with join table newScope := scope.New(reflect.New(fieldType).Interface()) preloadDB = preloadDB.Table(newScope.TableName()).Model(newScope.Value).Select("*") preloadDB = joinTableHandler.JoinWith(joinTableHandler, preloadDB, scope.Value) // preload inline conditions if len(preloadConditions) > 0 { preloadDB = preloadDB.Where(preloadConditions[0], preloadConditions[1:]...) } rows, err := preloadDB.Rows() if scope.Err(err) != nil { return } defer rows.Close() columns, _ := rows.Columns() for rows.Next() { var ( elem = reflect.New(fieldType).Elem() fields = scope.New(elem.Addr().Interface()).Fields() ) // register foreign keys in join tables var joinTableFields []*Field for _, sourceKey := range sourceKeys { joinTableFields = append(joinTableFields, &Field{StructField: &StructField{DBName: sourceKey, IsNormal: true}, Field: reflect.New(foreignKeyType).Elem()}) } scope.scan(rows, columns, append(fields, joinTableFields...)) var foreignKeys = make([]interface{}, len(sourceKeys)) // generate hashed forkey keys in join table for idx, joinTableField := range joinTableFields { if !joinTableField.Field.IsNil() { foreignKeys[idx] = joinTableField.Field.Elem().Interface() } } hashedSourceKeys := toString(foreignKeys) if isPtr { linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem.Addr()) } else { linkHash[hashedSourceKeys] = append(linkHash[hashedSourceKeys], elem) } } if err := rows.Err(); err != nil { scope.Err(err) } // assign find results var ( indirectScopeValue = scope.IndirectValue() fieldsSourceMap = map[string][]reflect.Value{} foreignFieldNames = []string{} ) for _, dbName := range relation.ForeignFieldNames { if field, ok := scope.FieldByName(dbName); ok { foreignFieldNames = append(foreignFieldNames, field.Name) } } if indirectScopeValue.Kind() == reflect.Slice { for j := 0; j < indirectScopeValue.Len(); j++ { object := indirect(indirectScopeValue.Index(j)) key := toString(getValueFromFields(object, foreignFieldNames)) fieldsSourceMap[key] = append(fieldsSourceMap[key], object.FieldByName(field.Name)) } } else if indirectScopeValue.IsValid() { key := toString(getValueFromFields(indirectScopeValue, foreignFieldNames)) fieldsSourceMap[key] = append(fieldsSourceMap[key], indirectScopeValue.FieldByName(field.Name)) } for source, link := range linkHash { for i, field := range fieldsSourceMap[source] { //If not 0 this means Value is a pointer and we already added preloaded models to it if fieldsSourceMap[source][i].Len() != 0 { continue } field.Set(reflect.Append(fieldsSourceMap[source][i], link...)) } } }