Studio Tagger (#3510)

* Studio image and parent studio support in scene tagger
* Refactor studio backend and add studio tagger
---------
Co-authored-by: WithoutPants <53250216+WithoutPants@users.noreply.github.com>
This commit is contained in:
Flashy78
2023-07-30 16:50:24 -07:00
committed by GitHub
parent d48dbeb864
commit a665a56ef0
79 changed files with 5224 additions and 1039 deletions

View File

@@ -119,7 +119,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View File

@@ -152,7 +152,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View File

@@ -58,13 +58,13 @@ func (_m *StudioReaderWriter) Count(ctx context.Context) (int, error) {
return r0, r1
}
// Create provides a mock function with given fields: ctx, newStudio
func (_m *StudioReaderWriter) Create(ctx context.Context, newStudio *models.Studio) error {
ret := _m.Called(ctx, newStudio)
// Create provides a mock function with given fields: ctx, input
func (_m *StudioReaderWriter) Create(ctx context.Context, input *models.Studio) error {
ret := _m.Called(ctx, input)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *models.Studio) error); ok {
r0 = rf(ctx, newStudio)
r0 = rf(ctx, input)
} else {
r0 = ret.Error(0)
}
@@ -155,6 +155,29 @@ func (_m *StudioReaderWriter) FindByStashID(ctx context.Context, stashID models.
return r0, r1
}
// FindByStashIDStatus provides a mock function with given fields: ctx, hasStashID, stashboxEndpoint
func (_m *StudioReaderWriter) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
ret := _m.Called(ctx, hasStashID, stashboxEndpoint)
var r0 []*models.Studio
if rf, ok := ret.Get(0).(func(context.Context, bool, string) []*models.Studio); ok {
r0 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*models.Studio)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, bool, string) error); ok {
r1 = rf(ctx, hasStashID, stashboxEndpoint)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// FindChildren provides a mock function with given fields: ctx, id
func (_m *StudioReaderWriter) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
ret := _m.Called(ctx, id)
@@ -201,13 +224,13 @@ func (_m *StudioReaderWriter) FindMany(ctx context.Context, ids []int) ([]*model
return r0, r1
}
// GetAliases provides a mock function with given fields: ctx, studioID
func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]string, error) {
ret := _m.Called(ctx, studioID)
// GetAliases provides a mock function with given fields: ctx, relatedID
func (_m *StudioReaderWriter) GetAliases(ctx context.Context, relatedID int) ([]string, error) {
ret := _m.Called(ctx, relatedID)
var r0 []string
if rf, ok := ret.Get(0).(func(context.Context, int) []string); ok {
r0 = rf(ctx, studioID)
r0 = rf(ctx, relatedID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]string)
@@ -216,7 +239,7 @@ func (_m *StudioReaderWriter) GetAliases(ctx context.Context, studioID int) ([]s
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
r1 = rf(ctx, studioID)
r1 = rf(ctx, relatedID)
} else {
r1 = ret.Error(1)
}
@@ -358,20 +381,6 @@ func (_m *StudioReaderWriter) Update(ctx context.Context, updatedStudio *models.
return r0
}
// UpdateAliases provides a mock function with given fields: ctx, studioID, aliases
func (_m *StudioReaderWriter) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
ret := _m.Called(ctx, studioID, aliases)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int, []string) error); ok {
r0 = rf(ctx, studioID, aliases)
} else {
r0 = ret.Error(0)
}
return r0
}
// UpdateImage provides a mock function with given fields: ctx, studioID, image
func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, image []byte) error {
ret := _m.Called(ctx, studioID, image)
@@ -386,13 +395,13 @@ func (_m *StudioReaderWriter) UpdateImage(ctx context.Context, studioID int, ima
return r0
}
// UpdatePartial provides a mock function with given fields: ctx, id, updatedStudio
func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, id int, updatedStudio models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(ctx, id, updatedStudio)
// UpdatePartial provides a mock function with given fields: ctx, input
func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
ret := _m.Called(ctx, input)
var r0 *models.Studio
if rf, ok := ret.Get(0).(func(context.Context, int, models.StudioPartial) *models.Studio); ok {
r0 = rf(ctx, id, updatedStudio)
if rf, ok := ret.Get(0).(func(context.Context, models.StudioPartial) *models.Studio); ok {
r0 = rf(ctx, input)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*models.Studio)
@@ -400,25 +409,11 @@ func (_m *StudioReaderWriter) UpdatePartial(ctx context.Context, id int, updated
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int, models.StudioPartial) error); ok {
r1 = rf(ctx, id, updatedStudio)
if rf, ok := ret.Get(1).(func(context.Context, models.StudioPartial) error); ok {
r1 = rf(ctx, input)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// UpdateStashIDs provides a mock function with given fields: ctx, studioID, stashIDs
func (_m *StudioReaderWriter) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
ret := _m.Called(ctx, studioID, stashIDs)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, int, []models.StashID) error); ok {
r0 = rf(ctx, studioID, stashIDs)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@@ -1,16 +1,108 @@
package models
import (
"context"
"strconv"
"time"
"github.com/stashapp/stash/pkg/utils"
)
type ScrapedStudio struct {
// Set if studio matched
StoredID *string `json:"stored_id"`
Name string `json:"name"`
URL *string `json:"url"`
Image *string `json:"image"`
RemoteSiteID *string `json:"remote_site_id"`
StoredID *string `json:"stored_id"`
Name string `json:"name"`
URL *string `json:"url"`
Parent *ScrapedStudio `json:"parent"`
Image *string `json:"image"`
Images []string `json:"images"`
RemoteSiteID *string `json:"remote_site_id"`
}
func (ScrapedStudio) IsScrapedContent() {}
func (s *ScrapedStudio) ToStudio(endpoint string, excluded map[string]bool) *Studio {
now := time.Now()
// Populate a new studio from the input
newStudio := Studio{
Name: s.Name,
StashIDs: NewRelatedStashIDs([]StashID{
{
Endpoint: endpoint,
StashID: *s.RemoteSiteID,
},
}),
CreatedAt: now,
UpdatedAt: now,
}
if s.URL != nil && !excluded["url"] {
newStudio.URL = *s.URL
}
if s.Parent != nil && s.Parent.StoredID != nil && !excluded["parent"] {
parentId, _ := strconv.Atoi(*s.Parent.StoredID)
newStudio.ParentID = &parentId
}
return &newStudio
}
func (s *ScrapedStudio) GetImage(ctx context.Context, excluded map[string]bool) ([]byte, error) {
// Process the base 64 encoded image string
if len(s.Images) > 0 && !excluded["image"] {
var err error
img, err := utils.ProcessImageInput(ctx, *s.Image)
if err != nil {
return nil, err
}
return img, nil
}
return nil, nil
}
func (s *ScrapedStudio) ToPartial(id *string, endpoint string, excluded map[string]bool, existingStashIDs []StashID) *StudioPartial {
partial := StudioPartial{
UpdatedAt: NewOptionalTime(time.Now()),
}
partial.ID, _ = strconv.Atoi(*id)
if s.Name != "" && !excluded["name"] {
partial.Name = NewOptionalString(s.Name)
}
if s.URL != nil && !excluded["url"] {
partial.URL = NewOptionalString(*s.URL)
}
if s.Parent != nil && !excluded["parent"] {
if s.Parent.StoredID != nil {
parentID, _ := strconv.Atoi(*s.Parent.StoredID)
if parentID > 0 {
// This is to be set directly as we know it has a value and the translator won't have the field
partial.ParentID = NewOptionalInt(parentID)
}
}
} else {
partial.ParentID = NewOptionalIntPtr(nil)
}
partial.StashIDs = &UpdateStashIDs{
StashIDs: existingStashIDs,
Mode: RelationshipUpdateModeSet,
}
partial.StashIDs.Set(StashID{
Endpoint: endpoint,
StashID: *s.RemoteSiteID,
})
return &partial
}
// A performer from a scraping operation...
type ScrapedPerformer struct {
// Set if performer matched

View File

@@ -0,0 +1,65 @@
package models
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_scrapedToStudioInput(t *testing.T) {
const name = "name"
url := "url"
remoteSiteID := "remoteSiteID"
tests := []struct {
name string
studio *ScrapedStudio
want *Studio
}{
{
"set all",
&ScrapedStudio{
Name: name,
URL: &url,
RemoteSiteID: &remoteSiteID,
},
&Studio{
Name: name,
URL: url,
StashIDs: NewRelatedStashIDs([]StashID{
{
StashID: remoteSiteID,
},
}),
},
},
{
"set none",
&ScrapedStudio{
Name: name,
RemoteSiteID: &remoteSiteID,
},
&Studio{
Name: name,
StashIDs: NewRelatedStashIDs([]StashID{
{
StashID: remoteSiteID,
},
}),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.studio.ToStudio("", nil)
assert.NotEqual(t, time.Time{}, got.CreatedAt)
assert.NotEqual(t, time.Time{}, got.UpdatedAt)
got.CreatedAt = time.Time{}
got.UpdatedAt = time.Time{}
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,6 +1,7 @@
package models
import (
"context"
"time"
)
@@ -15,34 +16,50 @@ type Studio struct {
Rating *int `json:"rating"`
Details string `json:"details"`
IgnoreAutoTag bool `json:"ignore_auto_tag"`
Aliases RelatedStrings `json:"aliases"`
StashIDs RelatedStashIDs `json:"stash_ids"`
}
func (s *Studio) LoadAliases(ctx context.Context, l AliasLoader) error {
return s.Aliases.load(func() ([]string, error) {
return l.GetAliases(ctx, s.ID)
})
}
func (s *Studio) LoadStashIDs(ctx context.Context, l StashIDLoader) error {
return s.StashIDs.load(func() ([]StashID, error) {
return l.GetStashIDs(ctx, s.ID)
})
}
func (s *Studio) LoadRelationships(ctx context.Context, l PerformerReader) error {
if err := s.LoadAliases(ctx, l); err != nil {
return err
}
if err := s.LoadStashIDs(ctx, l); err != nil {
return err
}
return nil
}
// StudioPartial represents part of a Studio object. It is used to update the database entry.
type StudioPartial struct {
Name OptionalString
URL OptionalString
ParentID OptionalInt
CreatedAt OptionalTime
UpdatedAt OptionalTime
ID int
Name OptionalString
URL OptionalString
ParentID OptionalInt
// Rating expressed in 1-100 scale
Rating OptionalInt
Details OptionalString
CreatedAt OptionalTime
UpdatedAt OptionalTime
IgnoreAutoTag OptionalBool
}
func NewStudio(name string) *Studio {
currentTime := time.Now()
return &Studio{
Name: name,
CreatedAt: currentTime,
UpdatedAt: currentTime,
}
}
func NewStudioPartial() StudioPartial {
updatedTime := time.Now()
return StudioPartial{
UpdatedAt: NewOptionalTime(updatedTime),
}
Aliases *UpdateStrings
StashIDs *UpdateStashIDs
}
type Studios []*Studio

View File

@@ -48,6 +48,7 @@ type StudioReader interface {
FindChildren(ctx context.Context, id int) ([]*Studio, error)
FindByName(ctx context.Context, name string, nocase bool) (*Studio, error)
FindByStashID(ctx context.Context, stashID StashID) ([]*Studio, error)
FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*Studio, error)
Count(ctx context.Context) (int, error)
All(ctx context.Context) ([]*Studio, error)
// TODO - this interface is temporary until the filter schema can fully
@@ -56,18 +57,16 @@ type StudioReader interface {
Query(ctx context.Context, studioFilter *StudioFilterType, findFilter *FindFilterType) ([]*Studio, int, error)
GetImage(ctx context.Context, studioID int) ([]byte, error)
HasImage(ctx context.Context, studioID int) (bool, error)
AliasLoader
StashIDLoader
GetAliases(ctx context.Context, studioID int) ([]string, error)
}
type StudioWriter interface {
Create(ctx context.Context, newStudio *Studio) error
UpdatePartial(ctx context.Context, id int, updatedStudio StudioPartial) (*Studio, error)
UpdatePartial(ctx context.Context, input StudioPartial) (*Studio, error)
Update(ctx context.Context, updatedStudio *Studio) error
Destroy(ctx context.Context, id int) error
UpdateImage(ctx context.Context, studioID int, image []byte) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []StashID) error
UpdateAliases(ctx context.Context, studioID int, aliases []string) error
}
type StudioReaderWriter interface {

View File

@@ -5,6 +5,7 @@ import (
"io"
"strconv"
"github.com/stashapp/stash/pkg/sliceutil"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
)
@@ -94,16 +95,7 @@ func (u *UpdateIDs) EffectiveIDs(existing []int) []int {
return nil
}
switch u.Mode {
case RelationshipUpdateModeAdd:
return intslice.IntAppendUniques(existing, u.IDs)
case RelationshipUpdateModeRemove:
return intslice.IntExclude(existing, u.IDs)
case RelationshipUpdateModeSet:
return u.IDs
}
return nil
return effectiveValues(u.IDs, u.Mode, existing)
}
type UpdateStrings struct {
@@ -118,3 +110,26 @@ func (u *UpdateStrings) Strings() []string {
return u.Values
}
// GetEffectiveIDs returns the new IDs that will be effective after the update.
func (u *UpdateStrings) EffectiveValues(existing []string) []string {
if u == nil {
return nil
}
return effectiveValues(u.Values, u.Mode, existing)
}
// effectiveValues returns the new values that will be effective after the update.
func effectiveValues[T comparable](values []T, mode RelationshipUpdateMode, existing []T) []T {
switch mode {
case RelationshipUpdateModeAdd:
return sliceutil.AppendUniques(existing, values)
case RelationshipUpdateModeRemove:
return sliceutil.Exclude(existing, values)
case RelationshipUpdateModeSet:
return values
}
return nil
}

View File

@@ -116,7 +116,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View File

@@ -176,7 +176,9 @@ func (i *Importer) populateStudio(ctx context.Context) error {
}
func (i *Importer) createStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.StudioWriter.Create(ctx, newStudio)
if err != nil {

View File

@@ -17,6 +17,7 @@ type StashBoxGraphQLClient interface {
SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error)
FindPerformerByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindPerformerByID, error)
FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error)
FindStudio(ctx context.Context, id *string, name *string, httpRequestOptions ...client.HTTPRequestOption) (*FindStudio, error)
SubmitFingerprint(ctx context.Context, input FingerprintSubmission, httpRequestOptions ...client.HTTPRequestOption) (*SubmitFingerprint, error)
Me(ctx context.Context, httpRequestOptions ...client.HTTPRequestOption) (*Me, error)
SubmitSceneDraft(ctx context.Context, input SceneDraftInput, httpRequestOptions ...client.HTTPRequestOption) (*SubmitSceneDraft, error)
@@ -125,9 +126,13 @@ type ImageFragment struct {
Height int "json:\"height\" graphql:\"height\""
}
type StudioFragment struct {
Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\""
Urls []*URLFragment "json:\"urls\" graphql:\"urls\""
Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\""
Urls []*URLFragment "json:\"urls\" graphql:\"urls\""
Parent *struct {
Name string "json:\"name\" graphql:\"name\""
ID string "json:\"id\" graphql:\"id\""
} "json:\"parent\" graphql:\"parent\""
Images []*ImageFragment "json:\"images\" graphql:\"images\""
}
type TagFragment struct {
@@ -215,6 +220,9 @@ type FindPerformerByID struct {
type FindSceneByID struct {
FindScene *SceneFragment "json:\"findScene\" graphql:\"findScene\""
}
type FindStudio struct {
FindStudio *StudioFragment "json:\"findStudio\" graphql:\"findStudio\""
}
type SubmitFingerprint struct {
SubmitFingerprint bool "json:\"submitFingerprint\" graphql:\"submitFingerprint\""
}
@@ -239,12 +247,77 @@ const FindSceneByFingerprintDocument = `query FindSceneByFingerprint ($fingerpri
... SceneFragment
}
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment TagFragment on Tag {
name
id
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment PerformerFragment on Performer {
id
name
@@ -279,76 +352,15 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
}
fragment TagFragment on Tag {
name
id
}
`
func (c *Client) FindSceneByFingerprint(ctx context.Context, fingerprint FingerprintQueryInput, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByFingerprint, error) {
@@ -369,6 +381,49 @@ const FindScenesByFullFingerprintsDocument = `query FindScenesByFullFingerprints
... SceneFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment PerformerFragment on Performer {
id
name
@@ -403,16 +458,6 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
title
@@ -440,35 +485,6 @@ fragment SceneFragment on Scene {
... FingerprintFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
}
fragment TagFragment on Tag {
name
id
@@ -499,28 +515,56 @@ const FindScenesBySceneFingerprintsDocument = `query FindScenesBySceneFingerprin
... SceneFragment
}
}
fragment StudioFragment on Studio {
fragment URLFragment on URL {
url
type
}
fragment TagFragment on Tag {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment URLFragment on URL {
url
type
fragment SceneFragment on Scene {
id
title
code
details
director
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment ImageFragment on Image {
id
@@ -528,10 +572,18 @@ fragment ImageFragment on Image {
width
height
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment PerformerFragment on Performer {
@@ -568,46 +620,14 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment SceneFragment on Scene {
id
title
code
details
director
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
date
urls {
... URLFragment
}
images {
... ImageFragment
}
studio {
... StudioFragment
}
tags {
... TagFragment
}
performers {
... PerformerAppearanceFragment
}
fingerprints {
... FingerprintFragment
}
}
fragment TagFragment on Tag {
name
id
}
`
@@ -629,6 +649,29 @@ const SearchSceneDocument = `query SearchScene ($term: String!) {
... SceneFragment
}
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment TagFragment on Tag {
name
id
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment SceneFragment on Scene {
id
title
@@ -660,32 +703,16 @@ fragment URLFragment on URL {
url
type
}
fragment TagFragment on Tag {
name
id
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment ImageFragment on Image {
id
url
width
height
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
@@ -730,14 +757,11 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
`
@@ -759,16 +783,6 @@ const SearchPerformerDocument = `query SearchPerformer ($term: String!) {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment BodyModificationFragment on BodyModification {
location
description
@@ -817,6 +831,16 @@ fragment ImageFragment on Image {
width
height
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
`
func (c *Client) SearchPerformer(ctx context.Context, term string, httpRequestOptions ...client.HTTPRequestOption) (*SearchPerformer, error) {
@@ -915,26 +939,25 @@ const FindSceneByIDDocument = `query FindSceneByID ($id: ID!) {
... SceneFragment
}
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment URLFragment on URL {
fragment ImageFragment on Image {
id
url
type
width
height
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment TagFragment on Tag {
name
@@ -974,13 +997,11 @@ fragment PerformerFragment on Performer {
... BodyModificationFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
fragment BodyModificationFragment on BodyModification {
location
description
fragment MeasurementsFragment on Measurements {
band_size
cup_size
waist
hip
}
fragment SceneFragment on Scene {
id
@@ -1009,22 +1030,29 @@ fragment SceneFragment on Scene {
... FingerprintFragment
}
}
fragment ImageFragment on Image {
id
fragment URLFragment on URL {
url
width
height
type
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
images {
... ImageFragment
fragment BodyModificationFragment on BodyModification {
location
description
}
fragment FingerprintFragment on Fingerprint {
algorithm
hash
duration
}
fragment PerformerAppearanceFragment on PerformerAppearance {
as
performer {
... PerformerFragment
}
}
fragment FuzzyDateFragment on FuzzyDate {
date
accuracy
}
`
func (c *Client) FindSceneByID(ctx context.Context, id string, httpRequestOptions ...client.HTTPRequestOption) (*FindSceneByID, error) {
@@ -1040,6 +1068,51 @@ func (c *Client) FindSceneByID(ctx context.Context, id string, httpRequestOption
return &res, nil
}
const FindStudioDocument = `query FindStudio ($id: ID, $name: String) {
findStudio(id: $id, name: $name) {
... StudioFragment
}
}
fragment StudioFragment on Studio {
name
id
urls {
... URLFragment
}
parent {
name
id
}
images {
... ImageFragment
}
}
fragment URLFragment on URL {
url
type
}
fragment ImageFragment on Image {
id
url
width
height
}
`
func (c *Client) FindStudio(ctx context.Context, id *string, name *string, httpRequestOptions ...client.HTTPRequestOption) (*FindStudio, error) {
vars := map[string]interface{}{
"id": id,
"name": name,
}
var res FindStudio
if err := c.Client.Post(ctx, "FindStudio", FindStudioDocument, &res, vars, httpRequestOptions...); err != nil {
return nil, err
}
return &res, nil
}
const SubmitFingerprintDocument = `mutation SubmitFingerprint ($input: FingerprintSubmission!) {
submitFingerprint(input: $input)
}

View File

@@ -88,9 +88,9 @@ type DraftEntity struct {
ID *string `json:"id,omitempty"`
}
func (DraftEntity) IsSceneDraftPerformer() {}
func (DraftEntity) IsSceneDraftStudio() {}
func (DraftEntity) IsSceneDraftTag() {}
func (DraftEntity) IsSceneDraftStudio() {}
func (DraftEntity) IsSceneDraftPerformer() {}
type DraftEntityInput struct {
Name string `json:"name"`
@@ -116,6 +116,7 @@ type Edit struct {
// Objects to merge with the target. Only applicable to merges
MergeSources []EditTarget `json:"merge_sources,omitempty"`
Operation OperationEnum `json:"operation"`
Bot bool `json:"bot"`
Details EditDetails `json:"details,omitempty"`
// Previous state of fields being modified - null if operation is create or delete.
OldDetails EditDetails `json:"old_details,omitempty"`
@@ -154,6 +155,8 @@ type EditInput struct {
// Only required for merge type
MergeSourceIds []string `json:"merge_source_ids,omitempty"`
Comment *string `json:"comment,omitempty"`
// Edit submitted by an automated script. Requires bot permission
Bot *bool `json:"bot,omitempty"`
}
type EditQueryInput struct {
@@ -172,11 +175,15 @@ type EditQueryInput struct {
// Filter by target id
TargetID *string `json:"target_id,omitempty"`
// Filter by favorite status
IsFavorite *bool `json:"is_favorite,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort EditSortEnum `json:"sort"`
IsFavorite *bool `json:"is_favorite,omitempty"`
// Filter by user voted status
Voted *UserVotedFilterEnum `json:"voted,omitempty"`
// Filter to bot edits only
IsBot *bool `json:"is_bot,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort EditSortEnum `json:"sort"`
}
type EditVote struct {
@@ -542,11 +549,24 @@ type PerformerQueryInput struct {
Tattoos *BodyModificationCriterionInput `json:"tattoos,omitempty"`
Piercings *BodyModificationCriterionInput `json:"piercings,omitempty"`
// Filter by performerfavorite status for the current user
IsFavorite *bool `json:"is_favorite,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort PerformerSortEnum `json:"sort"`
IsFavorite *bool `json:"is_favorite,omitempty"`
// Filter by a performer they have performed in scenes with
PerformedWith *string `json:"performed_with,omitempty"`
// Filter by a studio
StudioID *string `json:"studio_id,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort PerformerSortEnum `json:"sort"`
}
type PerformerScenesInput struct {
// Filter by another performer that also performs in the scenes
PerformedWith *string `json:"performed_with,omitempty"`
// Filter by a studio
StudioID *string `json:"studio_id,omitempty"`
// Filter by tags
Tags *MultiIDCriterionInput `json:"tags,omitempty"`
}
type PerformerStudio struct {
@@ -689,7 +709,9 @@ type SceneDestroyInput struct {
type SceneDraft struct {
ID *string `json:"id,omitempty"`
Title *string `json:"title,omitempty"`
Code *string `json:"code,omitempty"`
Details *string `json:"details,omitempty"`
Director *string `json:"director,omitempty"`
URL *URL `json:"url,omitempty"`
Date *string `json:"date,omitempty"`
Studio SceneDraftStudio `json:"studio,omitempty"`
@@ -774,11 +796,13 @@ type SceneQueryInput struct {
// Filter to only include scenes with these fingerprints
Fingerprints *MultiStringCriterionInput `json:"fingerprints,omitempty"`
// Filter by favorited entity
Favorites *FavoriteFilter `json:"favorites,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort SceneSortEnum `json:"sort"`
Favorites *FavoriteFilter `json:"favorites,omitempty"`
// Filter to scenes with fingerprints submitted by the user
HasFingerprintSubmissions *bool `json:"has_fingerprint_submissions,omitempty"`
Page int `json:"page"`
PerPage int `json:"per_page"`
Direction SortDirectionEnum `json:"direction"`
Sort SceneSortEnum `json:"sort"`
}
type SceneUpdateInput struct {
@@ -847,16 +871,17 @@ type StringCriterionInput struct {
}
type Studio struct {
ID string `json:"id"`
Name string `json:"name"`
Urls []*URL `json:"urls,omitempty"`
Parent *Studio `json:"parent,omitempty"`
ChildStudios []*Studio `json:"child_studios,omitempty"`
Images []*Image `json:"images,omitempty"`
Deleted bool `json:"deleted"`
IsFavorite bool `json:"is_favorite"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
ID string `json:"id"`
Name string `json:"name"`
Urls []*URL `json:"urls,omitempty"`
Parent *Studio `json:"parent,omitempty"`
ChildStudios []*Studio `json:"child_studios,omitempty"`
Images []*Image `json:"images,omitempty"`
Deleted bool `json:"deleted"`
IsFavorite bool `json:"is_favorite"`
Created time.Time `json:"created"`
Updated time.Time `json:"updated"`
Performers *QueryPerformersResultType `json:"performers,omitempty"`
}
func (Studio) IsSceneDraftStudio() {}
@@ -1775,6 +1800,7 @@ const (
PerformerSortEnumOCounter PerformerSortEnum = "O_COUNTER"
PerformerSortEnumCareerStartYear PerformerSortEnum = "CAREER_START_YEAR"
PerformerSortEnumDebut PerformerSortEnum = "DEBUT"
PerformerSortEnumLastScene PerformerSortEnum = "LAST_SCENE"
PerformerSortEnumCreatedAt PerformerSortEnum = "CREATED_AT"
PerformerSortEnumUpdatedAt PerformerSortEnum = "UPDATED_AT"
)
@@ -1786,6 +1812,7 @@ var AllPerformerSortEnum = []PerformerSortEnum{
PerformerSortEnumOCounter,
PerformerSortEnumCareerStartYear,
PerformerSortEnumDebut,
PerformerSortEnumLastScene,
PerformerSortEnumCreatedAt,
PerformerSortEnumUpdatedAt,
}
@@ -2136,6 +2163,51 @@ func (e TargetTypeEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type UserVotedFilterEnum string
const (
UserVotedFilterEnumAbstain UserVotedFilterEnum = "ABSTAIN"
UserVotedFilterEnumAccept UserVotedFilterEnum = "ACCEPT"
UserVotedFilterEnumReject UserVotedFilterEnum = "REJECT"
UserVotedFilterEnumNotVoted UserVotedFilterEnum = "NOT_VOTED"
)
var AllUserVotedFilterEnum = []UserVotedFilterEnum{
UserVotedFilterEnumAbstain,
UserVotedFilterEnumAccept,
UserVotedFilterEnumReject,
UserVotedFilterEnumNotVoted,
}
func (e UserVotedFilterEnum) IsValid() bool {
switch e {
case UserVotedFilterEnumAbstain, UserVotedFilterEnumAccept, UserVotedFilterEnumReject, UserVotedFilterEnumNotVoted:
return true
}
return false
}
func (e UserVotedFilterEnum) String() string {
return string(e)
}
func (e *UserVotedFilterEnum) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = UserVotedFilterEnum(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid UserVotedFilterEnum", str)
}
return nil
}
func (e UserVotedFilterEnum) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}
type ValidSiteTypeEnum string
const (

View File

@@ -2,6 +2,11 @@ package stashbox
import "github.com/stashapp/stash/pkg/models"
type StashBoxStudioQueryResult struct {
Query string `json:"query"`
Results []*models.ScrapedStudio `json:"results"`
}
type StashBoxPerformerQueryResult struct {
Query string `json:"query"`
Results []*models.ScrapedPerformer `json:"results"`

View File

@@ -18,6 +18,7 @@ import (
"golang.org/x/text/language"
"github.com/Yamashou/gqlgenc/graphqljson"
"github.com/gofrs/uuid"
"github.com/stashapp/stash/pkg/file"
"github.com/stashapp/stash/pkg/logger"
"github.com/stashapp/stash/pkg/match"
@@ -660,6 +661,26 @@ func performerFragmentToScrapedScenePerformer(p graphql.PerformerFragment) *mode
return sp
}
func studioFragmentToScrapedStudio(s graphql.StudioFragment) *models.ScrapedStudio {
images := []string{}
for _, image := range s.Images {
images = append(images, image.URL)
}
st := &models.ScrapedStudio{
Name: s.Name,
URL: findURL(s.Urls, "HOME"),
Images: images,
RemoteSiteID: &s.ID,
}
if len(st.Images) > 0 {
st.Image = &st.Images[0]
}
return st
}
func getFirstImage(ctx context.Context, client *http.Client, images []*graphql.ImageFragment) *string {
ret, err := fetchImage(ctx, client, images[0].URL)
if err != nil && !errors.Is(err, context.Canceled) {
@@ -725,20 +746,29 @@ func (c Client) sceneFragmentToScrapedScene(ctx context.Context, s *graphql.Scen
tqb := c.repository.Tag
if s.Studio != nil {
studioID := s.Studio.ID
ss.Studio = &models.ScrapedStudio{
Name: s.Studio.Name,
URL: findURL(s.Studio.Urls, "HOME"),
RemoteSiteID: &studioID,
}
if s.Studio.Images != nil && len(s.Studio.Images) > 0 {
ss.Studio.Image = &s.Studio.Images[0].URL
}
ss.Studio = studioFragmentToScrapedStudio(*s.Studio)
err := match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio, &c.box.Endpoint)
if err != nil {
return err
}
var parentStudio *graphql.FindStudio
if s.Studio.Parent != nil {
parentStudio, err = c.client.FindStudio(ctx, &s.Studio.Parent.ID, nil)
if err != nil {
return err
}
if parentStudio.FindStudio != nil {
ss.Studio.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ss.Studio.Parent, &c.box.Endpoint)
if err != nil {
return err
}
}
}
}
for _, p := range s.Performers {
@@ -799,6 +829,56 @@ func (c Client) FindStashBoxPerformerByName(ctx context.Context, name string) (*
return ret, nil
}
func (c Client) FindStashBoxStudio(ctx context.Context, query string) (*models.ScrapedStudio, error) {
var studio *graphql.FindStudio
_, err := uuid.FromString(query)
if err == nil {
// Confirmed the user passed in a Stash ID
studio, err = c.client.FindStudio(ctx, &query, nil)
} else {
// Otherwise assume they're searching on a name
studio, err = c.client.FindStudio(ctx, nil, &query)
}
if err != nil {
return nil, err
}
var ret *models.ScrapedStudio
if studio.FindStudio != nil {
if err := txn.WithReadTxn(ctx, c.txnManager, func(ctx context.Context) error {
ret = studioFragmentToScrapedStudio(*studio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret, &c.box.Endpoint)
if err != nil {
return err
}
if studio.FindStudio.Parent != nil {
parentStudio, err := c.client.FindStudio(ctx, &studio.FindStudio.Parent.ID, nil)
if err != nil {
return err
}
if parentStudio.FindStudio != nil {
ret.Parent = studioFragmentToScrapedStudio(*parentStudio.FindStudio)
err = match.ScrapedStudio(ctx, c.repository.Studio, ret.Parent, &c.box.Endpoint)
if err != nil {
return err
}
}
}
return nil
}); err != nil {
return nil, err
}
}
return ret, nil
}
func (c Client) GetUser(ctx context.Context) (*graphql.Me, error) {
return c.client.Me(ctx)
}

View File

@@ -438,21 +438,6 @@ func (r *stashIDRepository) get(ctx context.Context, id int) ([]models.StashID,
return []models.StashID(ret), err
}
func (r *stashIDRepository) replace(ctx context.Context, id int, newIDs []models.StashID) error {
if err := r.destroy(ctx, []int{id}); err != nil {
return err
}
query := fmt.Sprintf("INSERT INTO %s (%s, endpoint, stash_id) VALUES (?, ?, ?)", r.tableName, r.idColumn)
for _, stashID := range newIDs {
_, err := r.tx.Exec(ctx, query, id, stashID.Endpoint, stashID.StashID)
if err != nil {
return err
}
}
return nil
}
type filesRepository struct {
repository
}

View File

@@ -631,7 +631,7 @@ func populateDB() error {
return fmt.Errorf("error creating performers: %s", err.Error())
}
if err := createStudios(ctx, db.Studio, studiosNameCase, studiosNameNoCase); err != nil {
if err := createStudios(ctx, studiosNameCase, studiosNameNoCase); err != nil {
return fmt.Errorf("error creating studios: %s", err.Error())
}
@@ -659,7 +659,7 @@ func populateDB() error {
return fmt.Errorf("error linking movie studios: %s", err.Error())
}
if err := linkStudiosParent(ctx, db.Studio); err != nil {
if err := linkStudiosParent(ctx); err != nil {
return fmt.Errorf("error linking studios parent: %s", err.Error())
}
@@ -1310,8 +1310,8 @@ func createMovies(ctx context.Context, mqb models.MovieReaderWriter, n int, o in
name = getMovieStringValue(index, name)
movie := models.Movie{
Name: name,
URL: getMovieNullStringValue(index, urlField),
Name: name,
URL: getMovieNullStringValue(index, urlField),
}
err := mqb.Create(ctx, &movie)
@@ -1573,9 +1573,9 @@ func getStudioNullStringValue(index int, field string) string {
return ret.String
}
func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name string, parentID *int) (*models.Studio, error) {
func createStudio(ctx context.Context, sqb *sqlite.StudioStore, name string, parentID *int) (*models.Studio, error) {
studio := models.Studio{
Name: name,
Name: name,
}
if parentID != nil {
@@ -1590,7 +1590,7 @@ func createStudio(ctx context.Context, sqb models.StudioReaderWriter, name strin
return &studio, nil
}
func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, studio *models.Studio) error {
func createStudioFromModel(ctx context.Context, sqb *sqlite.StudioStore, studio *models.Studio) error {
err := sqb.Create(ctx, studio)
if err != nil {
@@ -1601,7 +1601,8 @@ func createStudioFromModel(ctx context.Context, sqb models.StudioReaderWriter, s
}
// createStudios creates n studios with plain Name and o studios with camel cased NaMe included
func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o int) error {
func createStudios(ctx context.Context, n int, o int) error {
sqb := db.Studio
const namePlain = "Name"
const nameNoCase = "NaMe"
@@ -1618,22 +1619,18 @@ func createStudios(ctx context.Context, sqb models.StudioReaderWriter, n int, o
name = getStudioStringValue(index, name)
studio := models.Studio{
Name: name,
URL: getStudioNullStringValue(index, urlField),
URL: getStudioStringValue(index, urlField),
IgnoreAutoTag: getIgnoreAutoTag(i),
}
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
// add alias
// only add aliases for some scenes
if i == studioIdxWithMovie || i%5 == 0 {
alias := getStudioStringValue(i, "Alias")
if err := sqb.UpdateAliases(ctx, studio.ID, []string{alias}); err != nil {
return fmt.Errorf("error setting studio alias: %s", err.Error())
}
studio.Aliases = models.NewRelatedStrings([]string{alias})
}
err := createStudioFromModel(ctx, sqb, &studio)
if err != nil {
return err
}
studioIDs = append(studioIDs, studio.ID)
@@ -1756,12 +1753,14 @@ func linkMovieStudios(ctx context.Context, mqb models.MovieWriter) error {
})
}
func linkStudiosParent(ctx context.Context, qb models.StudioWriter) error {
func linkStudiosParent(ctx context.Context) error {
qb := db.Studio
return doLinks(studioParentLinks, func(parentIndex, childIndex int) error {
studio := models.StudioPartial{
input := &models.StudioPartial{
ID: studioIDs[childIndex],
ParentID: models.NewOptionalInt(studioIDs[parentIndex]),
}
_, err := qb.UpdatePartial(ctx, studioIDs[childIndex], studio)
_, err := qb.UpdatePartial(ctx, *input)
return err
})

View File

@@ -5,7 +5,6 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/doug-martin/goqu/v9"
"github.com/doug-martin/goqu/v9/exp"
@@ -15,14 +14,16 @@ import (
"github.com/stashapp/stash/pkg/models"
"github.com/stashapp/stash/pkg/sliceutil/intslice"
"github.com/stashapp/stash/pkg/studio"
)
const (
studioTable = "studios"
studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias"
studioTable = "studios"
studioIDColumn = "studio_id"
studioAliasesTable = "studio_aliases"
studioAliasColumn = "alias"
studioParentIDColumn = "parent_id"
studioNameColumn = "name"
studioImageBlobColumn = "image_blob"
)
@@ -39,7 +40,7 @@ type studioRow struct {
IgnoreAutoTag bool `db:"ignore_auto_tag"`
// not used in resolutions or updates
CoverBlob zero.String `db:"image_blob"`
ImageBlob zero.String `db:"image_blob"`
}
func (r *studioRow) fromStudio(o models.Studio) {
@@ -116,6 +117,8 @@ func (qb *StudioStore) selectDataset() *goqu.SelectDataset {
}
func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) error {
var err error
var r studioRow
r.fromStudio(*newObject)
@@ -124,34 +127,66 @@ func (qb *StudioStore) Create(ctx context.Context, newObject *models.Studio) err
return err
}
updated, err := qb.find(ctx, id)
if newObject.Aliases.Loaded() {
if err := studio.EnsureAliasesUnique(ctx, id, newObject.Aliases.List(), qb); err != nil {
return err
}
if err := studiosAliasesTableMgr.insertJoins(ctx, id, newObject.Aliases.List()); err != nil {
return err
}
}
if newObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.insertJoins(ctx, id, newObject.StashIDs.List()); err != nil {
return err
}
}
updated, _ := qb.find(ctx, id)
if err != nil {
return fmt.Errorf("finding after create: %w", err)
}
*newObject = *updated
return nil
}
func (qb *StudioStore) UpdatePartial(ctx context.Context, id int, partial models.StudioPartial) (*models.Studio, error) {
func (qb *StudioStore) UpdatePartial(ctx context.Context, input models.StudioPartial) (*models.Studio, error) {
r := studioRowRecord{
updateRecord{
Record: make(exp.Record),
},
}
r.fromPartial(partial)
r.fromPartial(input)
if len(r.Record) > 0 {
if err := qb.tableMgr.updateByID(ctx, id, r.Record); err != nil {
if err := qb.tableMgr.updateByID(ctx, input.ID, r.Record); err != nil {
return nil, err
}
}
return qb.find(ctx, id)
if input.Aliases != nil {
if err := studio.EnsureAliasesUnique(ctx, input.ID, input.Aliases.Values, qb); err != nil {
return nil, err
}
if err := studiosAliasesTableMgr.modifyJoins(ctx, input.ID, input.Aliases.Values, input.Aliases.Mode); err != nil {
return nil, err
}
}
if input.StashIDs != nil {
if err := studiosStashIDsTableMgr.modifyJoins(ctx, input.ID, input.StashIDs.StashIDs, input.StashIDs.Mode); err != nil {
return nil, err
}
}
return qb.Find(ctx, input.ID)
}
// This is only used by the Import/Export functionality
func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio) error {
var r studioRow
r.fromStudio(*updatedObject)
@@ -160,6 +195,18 @@ func (qb *StudioStore) Update(ctx context.Context, updatedObject *models.Studio)
return err
}
if updatedObject.Aliases.Loaded() {
if err := studiosAliasesTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.Aliases.List()); err != nil {
return err
}
}
if updatedObject.StashIDs.Loaded() {
if err := studiosStashIDsTableMgr.replaceJoins(ctx, updatedObject.ID, updatedObject.StashIDs.List()); err != nil {
return err
}
}
return nil
}
@@ -257,10 +304,22 @@ func (qb *StudioStore) getMany(ctx context.Context, q *goqu.SelectDataset) ([]*m
return ret, nil
}
func (qb *StudioStore) findBySubquery(ctx context.Context, sq *goqu.SelectDataset) ([]*models.Studio, error) {
table := qb.table()
q := qb.selectDataset().Where(
table.Col(idColumn).Eq(
sq,
),
)
return qb.getMany(ctx, q)
}
func (qb *StudioStore) FindChildren(ctx context.Context, id int) ([]*models.Studio, error) {
// SELECT studios.* FROM studios WHERE studios.parent_id = ?
table := qb.table()
sq := qb.selectDataset().Where(table.Col("parent_id").Eq(id))
sq := qb.selectDataset().Where(table.Col(studioParentIDColumn).Eq(id))
ret, err := qb.getMany(ctx, sq)
if err != nil {
@@ -309,13 +368,44 @@ func (qb *StudioStore) FindByName(ctx context.Context, name string, nocase bool)
}
func (qb *StudioStore) FindByStashID(ctx context.Context, stashID models.StashID) ([]*models.Studio, error) {
query := selectAll("studios") + `
LEFT JOIN studio_stash_ids on studio_stash_ids.studio_id = studios.id
WHERE studio_stash_ids.stash_id = ?
AND studio_stash_ids.endpoint = ?
`
args := []interface{}{stashID.StashID, stashID.Endpoint}
return qb.queryStudios(ctx, query, args)
sq := dialect.From(studiosStashIDsJoinTable).Select(studiosStashIDsJoinTable.Col(studioIDColumn)).Where(
studiosStashIDsJoinTable.Col("stash_id").Eq(stashID.StashID),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashID.Endpoint),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash ID %s: %w", stashID.StashID, err)
}
return ret, nil
}
func (qb *StudioStore) FindByStashIDStatus(ctx context.Context, hasStashID bool, stashboxEndpoint string) ([]*models.Studio, error) {
table := qb.table()
sq := dialect.From(table).LeftJoin(
studiosStashIDsJoinTable,
goqu.On(table.Col(idColumn).Eq(studiosStashIDsJoinTable.Col(studioIDColumn))),
).Select(table.Col(idColumn))
if hasStashID {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNotNull(),
studiosStashIDsJoinTable.Col("endpoint").Eq(stashboxEndpoint),
)
} else {
sq = sq.Where(
studiosStashIDsJoinTable.Col("stash_id").IsNull(),
)
}
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting studios for stash-box endpoint %s: %w", stashboxEndpoint, err)
}
return ret, nil
}
func (qb *StudioStore) Count(ctx context.Context) (int, error) {
@@ -325,38 +415,37 @@ func (qb *StudioStore) Count(ctx context.Context) (int, error) {
func (qb *StudioStore) All(ctx context.Context) ([]*models.Studio, error) {
table := qb.table()
return qb.getMany(ctx, qb.selectDataset().Order(
table.Col("name").Asc(),
table.Col(idColumn).Asc(),
))
return qb.getMany(ctx, qb.selectDataset().Order(table.Col(studioNameColumn).Asc()))
}
func (qb *StudioStore) QueryForAutoTag(ctx context.Context, words []string) ([]*models.Studio, error) {
// TODO - Query needs to be changed to support queries of this type, and
// this method should be removed
query := selectAll(studioTable)
query += " LEFT JOIN studio_aliases ON studio_aliases.studio_id = studios.id"
table := qb.table()
sq := dialect.From(table).Select(table.Col(idColumn)).LeftJoin(
studiosAliasesJoinTable,
goqu.On(studiosAliasesJoinTable.Col(studioIDColumn).Eq(table.Col(idColumn))),
)
var whereClauses []string
var args []interface{}
var whereClauses []exp.Expression
for _, w := range words {
ww := w + "%"
whereClauses = append(whereClauses, "studios.name like ?")
args = append(args, ww)
// include aliases
whereClauses = append(whereClauses, "studio_aliases.alias like ?")
args = append(args, ww)
whereClauses = append(whereClauses, table.Col(studioNameColumn).Like(w+"%"))
whereClauses = append(whereClauses, studiosAliasesJoinTable.Col("alias").Like(w+"%"))
}
whereOr := "(" + strings.Join(whereClauses, " OR ") + ")"
where := strings.Join([]string{
"studios.ignore_auto_tag = 0",
whereOr,
}, " AND ")
return qb.queryStudios(ctx, query+" WHERE "+where, args)
sq = sq.Where(
goqu.Or(whereClauses...),
table.Col("ignore_auto_tag").Eq(0),
)
ret, err := qb.findBySubquery(ctx, sq)
if err != nil {
return nil, fmt.Errorf("getting performers for autotag: %w", err)
}
return ret, nil
}
func (qb *StudioStore) validateFilter(filter *models.StudioFilterType) error {
@@ -430,13 +519,13 @@ func (qb *StudioStore) makeFilter(ctx context.Context, studioFilter *models.Stud
query.handleCriterion(ctx, studioGalleryCountCriterionHandler(qb, studioFilter.GalleryCount))
query.handleCriterion(ctx, studioParentCriterionHandler(qb, studioFilter.Parents))
query.handleCriterion(ctx, studioAliasCriterionHandler(qb, studioFilter.Aliases))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, "studios.created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, "studios.updated_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.CreatedAt, studioTable+".created_at"))
query.handleCriterion(ctx, timestampCriterionHandler(studioFilter.UpdatedAt, studioTable+".updated_at"))
return query
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
func (qb *StudioStore) makeQuery(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) (*queryBuilder, error) {
if studioFilter == nil {
studioFilter = &models.StudioFilterType{}
}
@@ -450,20 +539,29 @@ func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFil
if q := findFilter.Q; q != nil && *q != "" {
query.join(studioAliasesTable, "", "studio_aliases.studio_id = studios.id")
searchColumns := []string{"studios.name", "studio_aliases.alias"}
query.parseQueryString(searchColumns, *q)
}
if err := qb.validateFilter(studioFilter); err != nil {
return nil, 0, err
return nil, err
}
filter := qb.makeFilter(ctx, studioFilter)
if err := query.addFilter(filter); err != nil {
return nil, 0, err
return nil, err
}
query.sortAndPagination = qb.getStudioSort(findFilter) + getPagination(findFilter)
return &query, nil
}
func (qb *StudioStore) Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error) {
query, err := qb.makeQuery(ctx, studioFilter, findFilter)
if err != nil {
return nil, 0, err
}
idsResult, countResult, err := query.executeFind(ctx)
if err != nil {
return nil, 0, err
@@ -546,7 +644,7 @@ func studioAliasCriterionHandler(qb *StudioStore, alias *models.StringCriterionI
joinTable: studioAliasesTable,
stringColumn: studioAliasColumn,
addJoinTable: func(f *filterBuilder) {
qb.aliasRepository().join(f, "", "studios.id")
studiosAliasesTableMgr.join(f, "", "studios.id")
},
}
@@ -581,26 +679,6 @@ func (qb *StudioStore) getStudioSort(findFilter *models.FindFilterType) string {
return sortQuery
}
func (qb *StudioStore) queryStudios(ctx context.Context, query string, args []interface{}) ([]*models.Studio, error) {
const single = false
var ret []*models.Studio
if err := qb.queryFunc(ctx, query, args, single, func(r *sqlx.Rows) error {
var f studioRow
if err := r.StructScan(&f); err != nil {
return err
}
s := f.resolve()
ret = append(ret, s)
return nil
}); err != nil {
return nil, err
}
return ret, nil
}
func (qb *StudioStore) GetImage(ctx context.Context, studioID int) ([]byte, error) {
return qb.blobJoinQueryBuilder.GetImage(ctx, studioID, studioImageBlobColumn)
}
@@ -628,28 +706,9 @@ func (qb *StudioStore) stashIDRepository() *stashIDRepository {
}
func (qb *StudioStore) GetStashIDs(ctx context.Context, studioID int) ([]models.StashID, error) {
return qb.stashIDRepository().get(ctx, studioID)
}
func (qb *StudioStore) UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error {
return qb.stashIDRepository().replace(ctx, studioID, stashIDs)
}
func (qb *StudioStore) aliasRepository() *stringRepository {
return &stringRepository{
repository: repository{
tx: qb.tx,
tableName: studioAliasesTable,
idColumn: studioIDColumn,
},
stringColumn: studioAliasColumn,
}
return studiosStashIDsTableMgr.get(ctx, studioID)
}
func (qb *StudioStore) GetAliases(ctx context.Context, studioID int) ([]string, error) {
return qb.aliasRepository().get(ctx, studioID)
}
func (qb *StudioStore) UpdateAliases(ctx context.Context, studioID int, aliases []string) error {
return qb.aliasRepository().replace(ctx, studioID, aliases)
return studiosAliasesTableMgr.get(ctx, studioID)
}

View File

@@ -219,18 +219,15 @@ func TestStudioQueryForAutoTag(t *testing.T) {
assert.Len(t, studios, 1)
assert.Equal(t, strings.ToLower(studioNames[studioIdxWithMovie]), strings.ToLower(studios[0].Name))
// find by alias
name = getStudioStringValue(studioIdxWithMovie, "Alias")
studios, err = tqb.QueryForAutoTag(ctx, []string{name})
if err != nil {
t.Errorf("Error finding studios: %s", err.Error())
}
if assert.Len(t, studios, 1) {
assert.Equal(t, studioIDs[studioIdxWithMovie], studios[0].ID)
}
return nil
})
}
@@ -363,11 +360,12 @@ func TestStudioUpdateClearParent(t *testing.T) {
sqb := db.Studio
// clear the parent id from the child
updatePartial := models.StudioPartial{
input := models.StudioPartial{
ID: createdChild.ID,
ParentID: models.NewOptionalIntPtr(nil),
}
updatedStudio, err := sqb.UpdatePartial(ctx, createdChild.ID, updatePartial)
updatedStudio, err := sqb.UpdatePartial(ctx, input)
if err != nil {
return fmt.Errorf("Error updated studio: %s", err.Error())
@@ -548,7 +546,7 @@ func verifyStudiosGalleryCount(t *testing.T, galleryCountCriterion models.IntCri
}
func TestStudioStashIDs(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio
// create studio to test against
@@ -558,13 +556,83 @@ func TestStudioStashIDs(t *testing.T) {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
testStashIDReaderWriter(ctx, t, qb, created.ID)
studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting studio: %s", err.Error())
}
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
testStudioStashIDs(ctx, t, studio)
return nil
}); err != nil {
t.Error(err.Error())
}
}
func testStudioStashIDs(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no stash IDs to begin with
assert.Len(t, s.StashIDs.List(), 0)
// add stash ids
const stashIDStr = "stashID"
const endpoint = "endpoint"
stashID := models.StashID{
StashID: stashIDStr,
Endpoint: endpoint,
}
// update stash ids and ensure was updated
input := models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, []models.StashID{stashID}, s.StashIDs.List())
// remove stash ids and ensure was updated
input = models.StudioPartial{
ID: s.ID,
StashIDs: &models.UpdateStashIDs{
StashIDs: []models.StashID{stashID},
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadStashIDs(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.StashIDs.List(), 0)
}
func TestStudioQueryURL(t *testing.T) {
const sceneIdx = 1
studioURL := getStudioStringValue(sceneIdx, urlField)
@@ -684,7 +752,7 @@ func TestStudioQueryIsMissingRating(t *testing.T) {
assert.True(t, len(studios) > 0)
for _, studio := range studios {
assert.True(t, studio.Rating == nil)
assert.Nil(t, studio.Rating)
}
return nil
@@ -778,36 +846,87 @@ func TestStudioQueryAlias(t *testing.T) {
verifyStudioQuery(t, studioFilter, verifyFn)
}
func TestStudioUpdateAlias(t *testing.T) {
if err := withTxn(func(ctx context.Context) error {
func TestStudioAlias(t *testing.T) {
if err := withRollbackTxn(func(ctx context.Context) error {
qb := db.Studio
// create studio to test against
const name = "TestStudioUpdateAlias"
created, err := createStudio(ctx, qb, name, nil)
const name = "TestStudioAlias"
created, err := createStudio(ctx, db.Studio, name, nil)
if err != nil {
return fmt.Errorf("Error creating studio: %s", err.Error())
}
aliases := []string{"alias1", "alias2"}
err = qb.UpdateAliases(ctx, created.ID, aliases)
studio, err := qb.Find(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error updating studio aliases: %s", err.Error())
return fmt.Errorf("Error getting studio: %s", err.Error())
}
// ensure aliases set
storedAliases, err := qb.GetAliases(ctx, created.ID)
if err != nil {
return fmt.Errorf("Error getting aliases: %s", err.Error())
if err := studio.LoadStashIDs(ctx, qb); err != nil {
return err
}
assert.Equal(t, aliases, storedAliases)
testStudioAlias(ctx, t, studio)
return nil
}); err != nil {
t.Error(err.Error())
}
}
func testStudioAlias(ctx context.Context, t *testing.T, s *models.Studio) {
qb := db.Studio
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
// ensure no alias to begin with
assert.Len(t, s.Aliases.List(), 0)
aliases := []string{"alias1", "alias2"}
// update alias and ensure was updated
input := models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeSet,
},
}
var err error
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Equal(t, aliases, s.Aliases.List())
// remove alias and ensure was updated
input = models.StudioPartial{
ID: s.ID,
Aliases: &models.UpdateStrings{
Values: aliases,
Mode: models.RelationshipUpdateModeRemove,
},
}
s, err = qb.UpdatePartial(ctx, input)
if err != nil {
t.Error(err.Error())
}
if err := s.LoadAliases(ctx, qb); err != nil {
t.Error(err.Error())
return
}
assert.Len(t, s.Aliases.List(), 0)
}
// TestStudioQueryFast does a quick test for major errors, no result verification
func TestStudioQueryFast(t *testing.T) {

View File

@@ -29,6 +29,9 @@ var (
performersAliasesJoinTable = goqu.T(performersAliasesTable)
performersTagsJoinTable = goqu.T(performersTagsTable)
performersStashIDsJoinTable = goqu.T("performer_stash_ids")
studiosAliasesJoinTable = goqu.T(studioAliasesTable)
studiosStashIDsJoinTable = goqu.T("studio_stash_ids")
)
var (
@@ -233,6 +236,21 @@ var (
table: goqu.T(studioTable),
idColumn: goqu.T(studioTable).Col(idColumn),
}
studiosAliasesTableMgr = &stringTable{
table: table{
table: studiosAliasesJoinTable,
idColumn: studiosAliasesJoinTable.Col(studioIDColumn),
},
stringColumn: studiosAliasesJoinTable.Col(studioAliasColumn),
}
studiosStashIDsTableMgr = &stashIDTable{
table: table{
table: studiosStashIDsJoinTable,
idColumn: studiosStashIDsJoinTable.Col(studioIDColumn),
},
}
)
var (

View File

@@ -11,15 +11,15 @@ import (
"github.com/stashapp/stash/pkg/utils"
)
type FinderImageStashIDGetter interface {
type FinderImageAliasStashIDGetter interface {
Finder
GetAliases(ctx context.Context, studioID int) ([]string, error)
GetImage(ctx context.Context, studioID int) ([]byte, error)
models.AliasLoader
models.StashIDLoader
}
// ToJSON converts a Studio object into its JSON equivalent.
func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) {
func ToJSON(ctx context.Context, reader FinderImageAliasStashIDGetter, studio *models.Studio) (*jsonschema.Studio, error) {
newStudioJSON := jsonschema.Studio{
Name: studio.Name,
URL: studio.URL,
@@ -44,12 +44,15 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
newStudioJSON.Rating = *studio.Rating
}
aliases, err := reader.GetAliases(ctx, studio.ID)
if err != nil {
return nil, fmt.Errorf("error getting studio aliases: %v", err)
if err := studio.LoadAliases(ctx, reader); err != nil {
return nil, fmt.Errorf("loading studio aliases: %w", err)
}
newStudioJSON.Aliases = studio.Aliases.List()
newStudioJSON.Aliases = aliases
if err := studio.LoadStashIDs(ctx, reader); err != nil {
return nil, fmt.Errorf("loading studio stash ids: %w", err)
}
newStudioJSON.StashIDs = studio.StashIDs.List()
image, err := reader.GetImage(ctx, studio.ID)
if err != nil {
@@ -60,17 +63,5 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
newStudioJSON.Image = utils.GetBase64StringFromData(image)
}
stashIDs, _ := reader.GetStashIDs(ctx, studio.ID)
var ret []models.StashID
for _, stashID := range stashIDs {
newJoin := models.StashID{
StashID: stashID.StashID,
Endpoint: stashID.Endpoint,
}
ret = append(ret, newJoin)
}
newStudioJSON.StashIDs = ret
return &newStudioJSON, nil
}

View File

@@ -15,12 +15,10 @@ import (
)
const (
studioID = 1
noImageID = 2
errImageID = 3
missingParentStudioID = 4
errStudioID = 5
errAliasID = 6
parentStudioID = 10
missingStudioID = 11
@@ -31,17 +29,19 @@ var (
studioName = "testStudio"
url = "url"
details = "details"
rating = 5
parentStudioName = "parentStudio"
autoTagIgnored = true
)
var studioID = 1
var rating = 5
var parentStudio models.Studio = models.Studio{
Name: parentStudioName,
}
var imageBytes = []byte("imageBytes")
var aliases = []string{"alias"}
var stashID = models.StashID{
StashID: "StashID",
Endpoint: "Endpoint",
@@ -67,6 +67,8 @@ func createFullStudio(id int, parentID int) models.Studio {
UpdatedAt: updateTime,
Rating: &rating,
IgnoreAutoTag: autoTagIgnored,
Aliases: models.NewRelatedStrings(aliases),
StashIDs: models.NewRelatedStashIDs(stashIDs),
}
if parentID != 0 {
@@ -81,6 +83,8 @@ func createEmptyStudio(id int) models.Studio {
ID: id,
CreatedAt: createTime,
UpdatedAt: updateTime,
Aliases: models.NewRelatedStrings([]string{}),
StashIDs: models.NewRelatedStashIDs([]models.StashID{}),
}
}
@@ -95,13 +99,11 @@ func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonsch
UpdatedAt: json.JSONTime{
Time: updateTime,
},
ParentStudio: parentStudio,
Image: image,
Rating: rating,
Aliases: aliases,
StashIDs: []models.StashID{
stashID,
},
ParentStudio: parentStudio,
Image: image,
Rating: rating,
Aliases: aliases,
StashIDs: stashIDs,
IgnoreAutoTag: autoTagIgnored,
}
}
@@ -114,6 +116,8 @@ func createEmptyJSONStudio() *jsonschema.Studio {
UpdatedAt: json.JSONTime{
Time: updateTime,
},
Aliases: []string{},
StashIDs: []models.StashID{},
}
}
@@ -139,13 +143,13 @@ func initTestTable() {
},
{
createFullStudio(errImageID, parentStudioID),
createFullJSONStudio(parentStudioName, "", nil),
createFullJSONStudio(parentStudioName, "", []string{"alias"}),
// failure to get image is not an error
false,
},
{
createFullStudio(missingParentStudioID, missingStudioID),
createFullJSONStudio("", image, nil),
createFullJSONStudio("", image, []string{"alias"}),
false,
},
{
@@ -153,11 +157,6 @@ func initTestTable() {
nil,
true,
},
{
createFullStudio(errAliasID, parentStudioID),
nil,
true,
},
}
}
@@ -174,7 +173,6 @@ func TestToJSON(t *testing.T) {
mockStudioReader.On("GetImage", ctx, errImageID).Return(nil, imageErr).Once()
mockStudioReader.On("GetImage", ctx, missingParentStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errStudioID).Return(imageBytes, nil).Maybe()
mockStudioReader.On("GetImage", ctx, errAliasID).Return(imageBytes, nil).Maybe()
parentStudioErr := errors.New("error getting parent studio")
@@ -182,19 +180,6 @@ func TestToJSON(t *testing.T) {
mockStudioReader.On("Find", ctx, missingStudioID).Return(nil, nil)
mockStudioReader.On("Find", ctx, errParentStudioID).Return(nil, parentStudioErr)
aliasErr := errors.New("error getting aliases")
mockStudioReader.On("GetAliases", ctx, studioID).Return([]string{"alias"}, nil).Once()
mockStudioReader.On("GetAliases", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, errImageID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, missingParentStudioID).Return(nil, nil).Once()
mockStudioReader.On("GetAliases", ctx, errAliasID).Return(nil, aliasErr).Once()
mockStudioReader.On("GetStashIDs", ctx, studioID).Return(stashIDs, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, noImageID).Return(nil, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, missingParentStudioID).Return(stashIDs, nil).Once()
mockStudioReader.On("GetStashIDs", ctx, errImageID).Return(stashIDs, nil).Once()
for i, s := range scenarios {
studio := s.input
json, err := ToJSON(ctx, mockStudioReader, &studio)

View File

@@ -14,8 +14,6 @@ type NameFinderCreatorUpdater interface {
NameFinderCreator
Update(ctx context.Context, updatedStudio *models.Studio) error
UpdateImage(ctx context.Context, studioID int, image []byte) error
UpdateAliases(ctx context.Context, studioID int, aliases []string) error
UpdateStashIDs(ctx context.Context, studioID int, stashIDs []models.StashID) error
}
var ErrParentStudioNotExist = errors.New("parent studio does not exist")
@@ -25,20 +23,13 @@ type Importer struct {
Input jsonschema.Studio
MissingRefBehaviour models.ImportMissingRefEnum
ID int
studio models.Studio
imageData []byte
}
func (i *Importer) PreImport(ctx context.Context) error {
i.studio = models.Studio{
Name: i.Input.Name,
URL: i.Input.URL,
Details: i.Input.Details,
IgnoreAutoTag: i.Input.IgnoreAutoTag,
CreatedAt: i.Input.CreatedAt.GetTime(),
UpdatedAt: i.Input.UpdatedAt.GetTime(),
Rating: &i.Input.Rating,
}
i.studio = studioJSONtoStudio(i.Input)
if err := i.populateParentStudio(ctx); err != nil {
return err
@@ -87,7 +78,9 @@ func (i *Importer) populateParentStudio(ctx context.Context) error {
}
func (i *Importer) createParentStudio(ctx context.Context, name string) (int, error) {
newStudio := models.NewStudio(name)
newStudio := &models.Studio{
Name: name,
}
err := i.ReaderWriter.Create(ctx, newStudio)
if err != nil {
@@ -104,16 +97,6 @@ func (i *Importer) PostImport(ctx context.Context, id int) error {
}
}
if len(i.Input.StashIDs) > 0 {
if err := i.ReaderWriter.UpdateStashIDs(ctx, id, i.Input.StashIDs); err != nil {
return fmt.Errorf("error setting stash id: %v", err)
}
}
if err := i.ReaderWriter.UpdateAliases(ctx, id, i.Input.Aliases); err != nil {
return fmt.Errorf("error setting tag aliases: %v", err)
}
return nil
}
@@ -156,3 +139,23 @@ func (i *Importer) Update(ctx context.Context, id int) error {
return nil
}
func studioJSONtoStudio(studioJSON jsonschema.Studio) models.Studio {
newStudio := models.Studio{
Name: studioJSON.Name,
URL: studioJSON.URL,
Aliases: models.NewRelatedStrings(studioJSON.Aliases),
Details: studioJSON.Details,
IgnoreAutoTag: studioJSON.IgnoreAutoTag,
CreatedAt: studioJSON.CreatedAt.GetTime(),
UpdatedAt: studioJSON.UpdatedAt.GetTime(),
StashIDs: models.NewRelatedStashIDs(studioJSON.StashIDs),
}
if studioJSON.Rating != 0 {
newStudio.Rating = &studioJSON.Rating
}
return newStudio
}

View File

@@ -164,15 +164,9 @@ func TestImporterPostImport(t *testing.T) {
}
updateStudioImageErr := errors.New("UpdateImage error")
updateTagAliasErr := errors.New("UpdateAlias error")
readerWriter.On("UpdateImage", ctx, studioID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateImage", ctx, errImageID, imageBytes).Return(updateStudioImageErr).Once()
readerWriter.On("UpdateImage", ctx, errAliasID, imageBytes).Return(nil).Once()
readerWriter.On("UpdateAliases", ctx, studioID, i.Input.Aliases).Return(nil).Once()
readerWriter.On("UpdateAliases", ctx, errImageID, i.Input.Aliases).Return(nil).Maybe()
readerWriter.On("UpdateAliases", ctx, errAliasID, i.Input.Aliases).Return(updateTagAliasErr).Once()
err := i.PostImport(ctx, studioID)
assert.Nil(t, err)
@@ -180,9 +174,6 @@ func TestImporterPostImport(t *testing.T) {
err = i.PostImport(ctx, errImageID)
assert.NotNil(t, err)
err = i.PostImport(ctx, errAliasID)
assert.NotNil(t, err)
readerWriter.AssertExpectations(t)
}

View File

@@ -14,6 +14,12 @@ type Queryer interface {
Query(ctx context.Context, studioFilter *models.StudioFilterType, findFilter *models.FindFilterType) ([]*models.Studio, int, error)
}
type FinderQueryer interface {
Finder
Queryer
models.AliasLoader
}
func ByName(ctx context.Context, qb Queryer, name string) (*models.Studio, error) {
f := &models.StudioFilterType{
Name: &models.StringCriterionInput{

View File

@@ -2,11 +2,16 @@ package studio
import (
"context"
"errors"
"fmt"
"github.com/stashapp/stash/pkg/models"
)
var (
ErrStudioOwnAncestor = errors.New("studio cannot be an ancestor of itself")
)
type NameFinderCreator interface {
FindByName(ctx context.Context, name string, nocase bool) (*models.Studio, error)
Create(ctx context.Context, newStudio *models.Studio) error
@@ -69,3 +74,60 @@ func EnsureAliasesUnique(ctx context.Context, id int, aliases []string, qb Query
return nil
}
// Checks to make sure that:
// 1. The studio exists locally
// 2. The studio is not its own ancestor
// 3. The studio's aliases are unique
func ValidateModify(ctx context.Context, s models.StudioPartial, qb FinderQueryer) error {
existing, err := qb.Find(ctx, s.ID)
if err != nil {
return err
}
if existing == nil {
return fmt.Errorf("studio with id %d not found", s.ID)
}
newParentID := s.ParentID.Ptr()
if newParentID != nil {
if err := validateParent(ctx, s.ID, *newParentID, qb); err != nil {
return err
}
}
if s.Aliases != nil {
if err := existing.LoadAliases(ctx, qb); err != nil {
return err
}
effectiveAliases := s.Aliases.EffectiveValues(existing.Aliases.List())
if err := EnsureAliasesUnique(ctx, s.ID, effectiveAliases, qb); err != nil {
return err
}
}
return nil
}
func validateParent(ctx context.Context, studioID int, newParentID int, qb FinderQueryer) error {
if newParentID == studioID {
return ErrStudioOwnAncestor
}
// ensure there is no cyclic dependency
parentStudio, err := qb.Find(ctx, newParentID)
if err != nil {
return fmt.Errorf("error finding parent studio: %v", err)
}
if parentStudio == nil {
return fmt.Errorf("studio with id %d not found", newParentID)
}
if parentStudio.ParentID != nil {
return validateParent(ctx, studioID, *parentStudio.ParentID, qb)
}
return nil
}